feat: nnue training data pipeline tools

dogeystamp 2024-12-30 12:21:07 -05:00
parent 3885bd7948
commit cb65671444
3 changed files with 235 additions and 0 deletions

.gitignore

@@ -1,2 +1,4 @@
 /target
 TODO.txt
+nnue/batches
+nnue/venv

nnue/batch_pgn_data.py (new file)

@@ -0,0 +1,50 @@
#!/usr/bin/env python
"""
Batch PGN data into files, since the training data pipeline can't resume processing within a single file.
"""
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "chess",
# ]
# ///

from typing import Iterator
import chess.pgn
import argparse
import itertools
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("files", nargs="+", type=Path)
parser.add_argument("--batch-size", type=int, help="Number of games to save in each output file.", default=8)
parser.add_argument("--output-folder", type=Path, help="Folder to save batched games in.", default=Path("batches"))
args = parser.parse_args()


def generate_games_in_file(path: Path) -> Iterator[chess.pgn.Game]:
    """Read games from a single PGN file."""
    with open(path) as f:
        while game := chess.pgn.read_game(f):
            # Tag each game with its source file path.
            game.headers["PGNPath"] = str(path)
            yield game


def generate_games() -> Iterator[chess.pgn.Game]:
    """Read games from all files."""
    for path in args.files:
        yield from generate_games_in_file(path)


def batch_games():
    """Write games in batches."""
    output_folder: Path = args.output_folder
    output_folder.mkdir(exist_ok=True)
    for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)):
        with (output_folder / f"batch{idx:04}.pgn").open("w") as f:
            for game in batch:
                # Separate games with a blank line, as PGN expects.
                f.write(str(game) + "\n\n")


batch_games()
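
The grouping is done by `itertools.batched` (new in Python 3.12, hence the `requires-python` pin above), which yields fixed-size tuples with a shorter final batch; a minimal sketch of its behavior:

    import itertools

    for batch in itertools.batched(range(10), 4):
        print(batch)
    # (0, 1, 2, 3)
    # (4, 5, 6, 7)
    # (8, 9)

An invocation might look like `python batch_pgn_data.py --batch-size 8 games/*.pgn` (input names hypothetical), producing batches/batch0000.pgn, batches/batch0001.pgn, and so on.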

nnue/process_pgn_data.py (new file)

@@ -0,0 +1,183 @@
#!/usr/bin/env python
"""
Processes PGN game data into a tsv format suitable for training.
Inputs from stdin, outputs to stdout.
Output columns:
- FEN (for reference)
- ALL 768-bit binary string representing the position
- Evaluation (centipawns) from white perspective
- Result of the game (-1, 0, 1)
This script depends on the `chess` package.
Install it, or run this script using `pipx run process_pgn_data.py`.
The script also depends on the chess_inator engine for analysis and filtering.
"""
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "chess",
# ]
# ///

import argparse
from asyncio import Queue, TaskGroup, create_task, run, sleep
import logging
import datetime
import multiprocessing
import csv

import chess
import chess.engine
from typing import AsyncIterator, Literal
from chess import pgn
from sys import stdin, stdout
from pathlib import Path

parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
    "--engine",
    help="Set the file path of the chess_inator engine used to analyze the positions.",
    type=Path,
)
parser.add_argument(
    "--max-workers",
    help="Max concurrent workers to analyse games with (limit this to your hardware thread count).",
    default=min(4, multiprocessing.cpu_count()),
    type=int,
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)

"""Skip this many plies from the start (avoid training on the opening)."""
SKIP_PLIES: int = 20

"""Time limit in seconds for each position to be analyzed."""
TIME_LIMIT: float = 5

output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()


async def load_games():
    """Load a PGN file and divide up the games for the workers to process."""
    while game := pgn.read_game(stdin):
        yield game
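
# The async generator above is shared by all worker tasks below: each
# `async for` iteration pulls the next unread game, so every game is
# analyzed by exactly one worker.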


async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
    """
    Single worker that analyzes whole games.

    Code pattern taken from https://stackoverflow.com/a/54975674.

    Puts rows of output into a global queue.
    """
    transport, engine = await chess.engine.popen_uci(args.engine)
    await engine.configure(dict(NNUETrainInfo="true"))

    async for game in game_generator:
        wdl: int | None = None

        match game.headers["Result"]:
            case "1-0":
                wdl = 1
            case "0-1":
                wdl = -1
            case "1/2-1/2":
                wdl = 0
            case other_result:
                logging.error("invalid 'Result' header: '%s'", other_result)
                continue

        board = game.board()
        skipped = 0
        logging.info(
            "Processing game %s, %s (%s) between %s as White and %s as Black.",
            game.headers["Event"],
            game.headers["Site"],
            game.headers["Date"],
            game.headers["White"],
            game.headers["Black"],
        )

        for move in game.mainline_moves():
            board.push(move)

            # Skip opening plies: we don't want to train on book positions.
            if skipped < SKIP_PLIES:
                skipped += 1
                continue

            result = await engine.play(
                board,
                chess.engine.Limit(time=TIME_LIMIT),
                info=chess.engine.INFO_ALL,
                game=game,
            )

            info_str = result.info.get("string")
            if not info_str:
                raise RuntimeError("Could not analyze position with engine.")
            (name, quiet, eval_abs, tensor) = info_str.split()
            if name != "NNUETrainInfo":
                raise RuntimeError(f"Unexpected output from engine: {info_str}")

            if quiet == "non-quiet":
                logging.debug("discarded as non-quiet: '%s'", board.fen())
                continue
            elif quiet != "quiet":
                raise RuntimeError(f"Unexpected output from engine: {info_str}")

            await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))

    # Shut the engine down cleanly once the game stream is exhausted.
    await engine.quit()
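
# For illustration: the NNUETrainInfo string parsed in `worker` is assumed to
# look like (values hypothetical)
#
#     NNUETrainInfo quiet 35 0100...0
#
# i.e. a marker, a quiescence flag ("quiet"/"non-quiet"), a white-relative
# centipawn evaluation, and the 768-bit input tensor; it becomes the queue
# row ("<fen>", "0100...0", 35, wdl).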


async def analyse_games():
    """Task that manages reading PGNs and analyzing them."""
    games_generator = load_games()

    async with TaskGroup() as tg:
        worker_count: int = min(args.max_workers, multiprocessing.cpu_count())
        logging.info("Using %d concurrent worker tasks.", worker_count)
        for i in range(worker_count):
            tg.create_task(worker(games_generator))


completed = 0
start_time = datetime.datetime.now()


async def output_rows():
    """TSV writer task."""
    writer = csv.writer(stdout, delimiter="\t")
    while True:
        row = await output_queue.get()
        writer.writerow(row)
        stdout.flush()
        output_queue.task_done()

        global completed
        completed += 1


async def status_logger():
    """Periodically print status."""
    while True:
        await sleep(5)
        logging.info(
            "Completed %d rows in %f seconds.",
            completed,
            (datetime.datetime.now() - start_time).total_seconds(),
        )


async def main():
    analyse_task = create_task(analyse_games())
    output_task = create_task(output_rows())
    status_task = create_task(status_logger())

    await analyse_task
    # Let the writer drain any rows still queued before tearing it down.
    await output_queue.join()
    output_task.cancel()
    status_task.cancel()


run(main())
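
Taken together, the two scripts form the pipeline: batch the PGNs, then feed each batch through the analyzer. A minimal end-to-end sketch (file names and the engine path are hypothetical):

    python batch_pgn_data.py --batch-size 8 games/*.pgn
    python process_pgn_data.py --engine ./target/release/chess_inator \
        < batches/batch0000.pgn > batch0000.tsv

Each output row then holds the four tab-separated columns described in the module docstring: FEN, input tensor bitstring, centipawn evaluation, and game result.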