feat: better integrate nnue train pipeline

This commit is contained in:
dogeystamp 2024-12-30 15:39:56 -05:00
parent e18464eceb
commit 4bd2fd1d9a
No known key found for this signature in database
3 changed files with 113 additions and 31 deletions

20
nnue/README.md Normal file
View File

@ -0,0 +1,20 @@
# NNUE training tools
Python training pipeline for the evaluation neural network.
See the docstring in `src/nnue.rs` for information about the architecture of the NNUE.
The network is trained on both self-play games, and its games on Lichess.
Both of these sources provide games in PGN format.
This folder includes the following scripts:
- `batch_pgn_data.py`: Combine and convert big PGN files into small chunked files.
- `process_pgn_data.py`: Convert PGN data into a format suitable for training.
Example training pipeline:
```bash
# chunk all the PGN files in `games/`. outputs by default to `batches/batch%d.pgn`.
./batch_pgn_data.py games/*.pgn
# analyze batches 0 to 20 to turn them into training data. outputs by default to train_data/batch%d.tsv.gz.
# set max-workers to the number of hardware threads / cores you have.
./process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch{0..20}.pgn
```

4
nnue/batch_pgn_data.py Normal file → Executable file
View File

@ -43,8 +43,8 @@ def batch_games():
output_folder: Path = args.output_folder output_folder: Path = args.output_folder
output_folder.mkdir(exist_ok=True) output_folder.mkdir(exist_ok=True)
for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)): for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)):
with (output_folder / f"batch{idx:04}.pgn").open("w") as f: with (output_folder / f"batch{idx}.pgn").open("w") as f:
for game in batch: for game in batch:
f.write(str(game) + "\n") f.write(str(game) + "\n\n")
batch_games() batch_games()

102
nnue/process_pgn_data.py Normal file → Executable file
View File

@ -2,7 +2,6 @@
""" """
Processes PGN game data into a tsv format suitable for training. Processes PGN game data into a tsv format suitable for training.
Inputs from stdin, outputs to stdout.
Output columns: Output columns:
- FEN (for reference) - FEN (for reference)
@ -27,18 +26,26 @@ from asyncio import Queue, TaskGroup, create_task, run, sleep
import logging import logging
import datetime import datetime
import multiprocessing import multiprocessing
import gzip
import csv import csv
import chess import chess
import chess.engine import chess.engine
from typing import AsyncIterator, Literal from typing import AsyncIterator, Literal
from chess import pgn from chess import pgn
from sys import stdin, stdout
from pathlib import Path from pathlib import Path
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
) )
parser.add_argument(
"--log",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
help="Sets log level.",
)
parser.add_argument( parser.add_argument(
"--engine", "--engine",
help="Set the file path of the chess_inator engine used to analyze the positions.", help="Set the file path of the chess_inator engine used to analyze the positions.",
@ -50,25 +57,39 @@ parser.add_argument(
default=min(4, multiprocessing.cpu_count()), default=min(4, multiprocessing.cpu_count()),
type=int, type=int,
) )
parser.add_argument(
"--preserve-partial",
action="store_true",
help="Keep output files that have not been fully written. These files may confuse this script when resuming operations.",
)
parser.add_argument("files", nargs="+", type=Path)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=getattr(logging, str.upper(args.log)))
"""Skip these many plies from the start (avoid training on opening).""" """Skip these many plies from the start (avoid training on opening)."""
SKIP_PLIES: int = 20 SKIP_PLIES: int = 20
"""Time limit in seconds for each position to be analyzed.""" """Time limit in seconds for each position to be analyzed."""
TIME_LIMIT: float = 5 TIME_LIMIT: float = 3
output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue() output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()
async def load_games(): # stats for progress
completed = 0
discarded = 0
current_outp: Path | None = None
start_time = datetime.datetime.now()
async def load_games(file: Path):
"""Load a PGN file and divide up the games for the workers to process.""" """Load a PGN file and divide up the games for the workers to process."""
while game := pgn.read_game(stdin): with open(file) as f:
while game := pgn.read_game(f):
yield game yield game
@ -101,7 +122,14 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
skipped = 0 skipped = 0
logging.info("Processing game %s, %s (%s) between %s as White and %s as Black.", game.headers["Event"], game.headers["Site"], game.headers["Date"], game.headers["White"], game.headers["Black"]) logging.info(
"Processing game %s, %s (%s) between %s as White and %s as Black.",
game.headers["Event"],
game.headers["Site"],
game.headers["Date"],
game.headers["White"],
game.headers["Black"],
)
for move in game.mainline_moves(): for move in game.mainline_moves():
board.push(move) board.push(move)
@ -123,6 +151,8 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
raise RuntimeError(f"Unexpected output from engine: {info_str}") raise RuntimeError(f"Unexpected output from engine: {info_str}")
if quiet == "non-quiet": if quiet == "non-quiet":
global discarded
discarded += 1
logging.debug("discarded as non-quiet: '%s'", board.fen()) logging.debug("discarded as non-quiet: '%s'", board.fen())
continue continue
elif quiet != "quiet": elif quiet != "quiet":
@ -131,9 +161,9 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
await output_queue.put((board.fen(), tensor, int(eval_abs), wdl)) await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))
async def analyse_games(): async def analyse_games(file: Path):
"""Task that manages reading PGNs and analyzing them.""" """Task that manages reading PGNs and analyzing them."""
games_generator = load_games() games_generator = load_games(file)
async with TaskGroup() as tg: async with TaskGroup() as tg:
worker_count: int = min(args.max_workers, multiprocessing.cpu_count()) worker_count: int = min(args.max_workers, multiprocessing.cpu_count())
@ -142,18 +172,14 @@ async def analyse_games():
tg.create_task(worker(games_generator)) tg.create_task(worker(games_generator))
completed = 0 async def output_rows(outp_file: Path):
start_time = datetime.datetime.now()
async def output_rows():
"""TSV writer task.""" """TSV writer task."""
writer = csv.writer(stdout, delimiter="\t") with gzip.open(outp_file, "wt") as f:
writer = csv.writer(f, delimiter="\t")
while True: while True:
row = await output_queue.get() row = await output_queue.get()
writer.writerow(row) writer.writerow(row)
stdout.flush()
output_queue.task_done() output_queue.task_done()
global completed global completed
completed += 1 completed += 1
@ -164,20 +190,56 @@ async def status_logger():
while True: while True:
await sleep(5) await sleep(5)
logging.info( logging.info(
"Completed %d rows in %f seconds.", "Completed %d rows in %f seconds. Discarded %d non-quiet positions.",
completed, completed,
(datetime.datetime.now() - start_time).total_seconds(), (datetime.datetime.now() - start_time).total_seconds(),
discarded,
) )
async def main(): async def main():
analyse_task = create_task(analyse_games())
output_task = create_task(output_rows())
status_task = create_task(status_logger()) status_task = create_task(status_logger())
outp_dir = Path("train_data")
outp_dir.mkdir(exist_ok=True)
any_file = False
skipped = False
for file in args.files:
file: Path
outp_file = outp_dir / file.with_suffix(".tsv.gz").name
if outp_file.exists():
skipped = True
continue
any_file = True
if skipped:
logging.info("Resuming at file '%s'.", file)
else:
logging.info("Reading file '%s'.", file)
global current_outp
current_outp = outp_file
output_task = create_task(output_rows(outp_file))
analyse_task = create_task(analyse_games(file))
await analyse_task await analyse_task
output_task.cancel() output_task.cancel()
if not any_file:
logging.warning("Nothing to do. All input files have outputs already.")
status_task.cancel() status_task.cancel()
run(main()) try:
run(main())
except KeyboardInterrupt:
logging.critical("shutting down.")
if current_outp and not args.preserve_partial:
logging.critical("discarding partial output file %s", current_outp)
current_outp.unlink()