From 4bd2fd1d9acf22045a66daaddd80194ec8e2b661 Mon Sep 17 00:00:00 2001 From: dogeystamp Date: Mon, 30 Dec 2024 15:39:56 -0500 Subject: [PATCH] feat: better integrate nnue train pipeline --- nnue/README.md | 20 +++++++ nnue/batch_pgn_data.py | 4 +- nnue/process_pgn_data.py | 120 +++++++++++++++++++++++++++++---------- 3 files changed, 113 insertions(+), 31 deletions(-) create mode 100644 nnue/README.md mode change 100644 => 100755 nnue/batch_pgn_data.py mode change 100644 => 100755 nnue/process_pgn_data.py diff --git a/nnue/README.md b/nnue/README.md new file mode 100644 index 0000000..a3acf5b --- /dev/null +++ b/nnue/README.md @@ -0,0 +1,20 @@ +# NNUE training tools + +Python training pipeline for the evaluation neural network. +See the docstring in `src/nnue.rs` for information about the architecture of the NNUE. +The network is trained on both self-play games, and its games on Lichess. +Both of these sources provide games in PGN format. + +This folder includes the following scripts: +- `batch_pgn_data.py`: Combine and convert big PGN files into small chunked files. +- `process_pgn_data.py`: Convert PGN data into a format suitable for training. + +Example training pipeline: +```bash +# chunk all the PGN files in `games/`. outputs by default to `batches/batch%d.pgn`. +./batch_pgn_data.py games/*.pgn + +# analyze batches 0 to 20 to turn them into training data. outputs by default to train_data/batch%d.tsv.gz. +# set max-workers to the number of hardware threads / cores you have. +./process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch{0..20}.pgn +``` diff --git a/nnue/batch_pgn_data.py b/nnue/batch_pgn_data.py old mode 100644 new mode 100755 index 94fddd1..3df1e6e --- a/nnue/batch_pgn_data.py +++ b/nnue/batch_pgn_data.py @@ -43,8 +43,8 @@ def batch_games(): output_folder: Path = args.output_folder output_folder.mkdir(exist_ok=True) for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)): - with (output_folder / f"batch{idx:04}.pgn").open("w") as f: + with (output_folder / f"batch{idx}.pgn").open("w") as f: for game in batch: - f.write(str(game) + "\n") + f.write(str(game) + "\n\n") batch_games() diff --git a/nnue/process_pgn_data.py b/nnue/process_pgn_data.py old mode 100644 new mode 100755 index d177da1..c3790e8 --- a/nnue/process_pgn_data.py +++ b/nnue/process_pgn_data.py @@ -2,7 +2,6 @@ """ Processes PGN game data into a tsv format suitable for training. -Inputs from stdin, outputs to stdout. Output columns: - FEN (for reference) @@ -27,18 +26,26 @@ from asyncio import Queue, TaskGroup, create_task, run, sleep import logging import datetime import multiprocessing +import gzip import csv import chess import chess.engine from typing import AsyncIterator, Literal from chess import pgn -from sys import stdin, stdout from pathlib import Path parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) + +parser.add_argument( + "--log", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + default="INFO", + help="Sets log level.", +) + parser.add_argument( "--engine", help="Set the file path of the chess_inator engine used to analyze the positions.", @@ -50,26 +57,40 @@ parser.add_argument( default=min(4, multiprocessing.cpu_count()), type=int, ) +parser.add_argument( + "--preserve-partial", + action="store_true", + help="Keep output files that have not been fully written. These files may confuse this script when resuming operations.", +) +parser.add_argument("files", nargs="+", type=Path) args = parser.parse_args() -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=getattr(logging, str.upper(args.log))) """Skip these many plies from the start (avoid training on opening).""" SKIP_PLIES: int = 20 """Time limit in seconds for each position to be analyzed.""" -TIME_LIMIT: float = 5 +TIME_LIMIT: float = 3 output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue() -async def load_games(): +# stats for progress +completed = 0 +discarded = 0 +current_outp: Path | None = None +start_time = datetime.datetime.now() + + +async def load_games(file: Path): """Load a PGN file and divide up the games for the workers to process.""" - while game := pgn.read_game(stdin): - yield game + with open(file) as f: + while game := pgn.read_game(f): + yield game async def worker(game_generator: AsyncIterator[pgn.Game]) -> None: @@ -101,7 +122,14 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None: skipped = 0 - logging.info("Processing game %s, %s (%s) between %s as White and %s as Black.", game.headers["Event"], game.headers["Site"], game.headers["Date"], game.headers["White"], game.headers["Black"]) + logging.info( + "Processing game %s, %s (%s) between %s as White and %s as Black.", + game.headers["Event"], + game.headers["Site"], + game.headers["Date"], + game.headers["White"], + game.headers["Black"], + ) for move in game.mainline_moves(): board.push(move) @@ -123,6 +151,8 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None: raise RuntimeError(f"Unexpected output from engine: {info_str}") if quiet == "non-quiet": + global discarded + discarded += 1 logging.debug("discarded as non-quiet: '%s'", board.fen()) continue elif quiet != "quiet": @@ -131,9 +161,9 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None: await output_queue.put((board.fen(), tensor, int(eval_abs), wdl)) -async def analyse_games(): +async def analyse_games(file: Path): """Task that manages reading PGNs and analyzing them.""" - games_generator = load_games() + games_generator = load_games(file) async with TaskGroup() as tg: worker_count: int = min(args.max_workers, multiprocessing.cpu_count()) @@ -142,21 +172,17 @@ async def analyse_games(): tg.create_task(worker(games_generator)) -completed = 0 -start_time = datetime.datetime.now() - - -async def output_rows(): +async def output_rows(outp_file: Path): """TSV writer task.""" - writer = csv.writer(stdout, delimiter="\t") - while True: - row = await output_queue.get() - writer.writerow(row) - stdout.flush() - output_queue.task_done() - global completed - completed += 1 + with gzip.open(outp_file, "wt") as f: + writer = csv.writer(f, delimiter="\t") + while True: + row = await output_queue.get() + writer.writerow(row) + output_queue.task_done() + global completed + completed += 1 async def status_logger(): @@ -164,20 +190,56 @@ async def status_logger(): while True: await sleep(5) logging.info( - "Completed %d rows in %f seconds.", + "Completed %d rows in %f seconds. Discarded %d non-quiet positions.", completed, (datetime.datetime.now() - start_time).total_seconds(), + discarded, ) async def main(): - analyse_task = create_task(analyse_games()) - output_task = create_task(output_rows()) status_task = create_task(status_logger()) - await analyse_task - output_task.cancel() + outp_dir = Path("train_data") + outp_dir.mkdir(exist_ok=True) + + any_file = False + skipped = False + + for file in args.files: + file: Path + + outp_file = outp_dir / file.with_suffix(".tsv.gz").name + + if outp_file.exists(): + skipped = True + continue + + any_file = True + + if skipped: + logging.info("Resuming at file '%s'.", file) + else: + logging.info("Reading file '%s'.", file) + + global current_outp + current_outp = outp_file + + output_task = create_task(output_rows(outp_file)) + analyse_task = create_task(analyse_games(file)) + await analyse_task + output_task.cancel() + + if not any_file: + logging.warning("Nothing to do. All input files have outputs already.") + status_task.cancel() -run(main()) +try: + run(main()) +except KeyboardInterrupt: + logging.critical("shutting down.") + if current_outp and not args.preserve_partial: + logging.critical("discarding partial output file %s", current_outp) + current_outp.unlink()