feat: nnue training data pipeline tools
This commit is contained in:
parent
3885bd7948
commit
cb65671444
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,2 +1,4 @@
|
||||
/target
|
||||
TODO.txt
|
||||
nnue/batches
|
||||
nnue/venv
|
||||
|
50
nnue/batch_pgn_data.py
Normal file
50
nnue/batch_pgn_data.py
Normal file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Batch PGN data into files, since the training data pipeline can't resume processing within a single file.
|
||||
"""
|
||||
|
||||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "chess",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
from typing import Iterator
|
||||
import chess.pgn
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
"""Games to include per file in output."""
|
||||
|
||||
# Command-line interface: the PGN files to read, plus the batching options.
parser = argparse.ArgumentParser()
parser.add_argument("files", nargs="+", type=Path)
parser.add_argument(
    "--batch-size",
    type=int,
    default=8,
    help="Number of games to save in each output file.",
)
parser.add_argument(
    "--output-folder",
    type=Path,
    default=Path("batches"),
    help="Folder to save batched games in.",
)
args = parser.parse_args()
|
||||
|
||||
def generate_games_in_file(path: Path) -> Iterator[chess.pgn.Game]:
    """Yield every game parsed from the PGN file at ``path``.

    Each yielded game has its ``PGNPath`` header set to ``str(path)`` so the
    source file can be identified later.
    """
    with open(path) as pgn_file:
        while True:
            game = chess.pgn.read_game(pgn_file)
            if not game:
                # read_game produced nothing more: end of this file.
                break
            game.headers["PGNPath"] = str(path)
            yield game
|
||||
|
||||
def generate_games() -> Iterator[chess.pgn.Game]:
    """Yield the games of every file in ``args.files``, in command-line order."""
    for pgn_path in args.files:
        for game in generate_games_in_file(pgn_path):
            yield game
|
||||
|
||||
def batch_games() -> None:
    """Write all input games into numbered PGN files.

    Games from ``generate_games()`` are grouped into batches of
    ``args.batch_size`` and written to ``batchNNNN.pgn`` (zero-padded batch
    index) inside ``args.output_folder``. The final batch may contain fewer
    games than ``args.batch_size``.
    """
    output_folder: Path = args.output_folder
    # parents=True so that a nested --output-folder (e.g. "data/batches")
    # can be created in one go instead of raising FileNotFoundError.
    output_folder.mkdir(parents=True, exist_ok=True)

    batches = itertools.batched(generate_games(), args.batch_size)
    for idx, batch in enumerate(batches):
        with (output_folder / f"batch{idx:04}.pgn").open("w") as f:
            for game in batch:
                # Terminate each game with a blank line: PGN separates
                # consecutive games with an empty line, and python-chess
                # documents writing games with end="\n\n".
                f.write(str(game) + "\n\n")


if __name__ == "__main__":
    batch_games()
|
183
nnue/process_pgn_data.py
Normal file
183
nnue/process_pgn_data.py
Normal file
@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Processes PGN game data into a tsv format suitable for training.
|
||||
Inputs from stdin, outputs to stdout.
|
||||
|
||||
Output columns:
|
||||
- FEN (for reference)
|
||||
- ALL 768-bit binary string representing the position
|
||||
- Evaluation (centipawns) from white perspective
|
||||
- Result of the game (-1, 0, 1)
|
||||
|
||||
This script depends on the `chess` package.
|
||||
Install it, or run this script using `pipx run process_pgn_data.py`.
|
||||
The script also depends on the chess_inator engine for analysis and filtering.
|
||||
"""
|
||||
|
||||
# /// script
|
||||
# requires-python = ">=3.11"
|
||||
# dependencies = [
|
||||
# "chess",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
from asyncio import Queue, TaskGroup, create_task, run, sleep
|
||||
import logging
|
||||
import datetime
|
||||
import multiprocessing
|
||||
import csv
|
||||
|
||||
import chess
|
||||
import chess.engine
|
||||
from typing import AsyncIterator, Literal
|
||||
from chess import pgn
|
||||
from sys import stdin, stdout
|
||||
from pathlib import Path
|
||||
|
||||
# Command-line interface. The module docstring doubles as the --help text,
# so it is rendered verbatim (RawDescriptionHelpFormatter keeps its layout).
parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
    "--engine",
    help="Set the file path of the chess_inator engine used to analyze the positions.",
    type=Path,
    # Every worker launches this binary; requiring it here reports a missing
    # option as a usage error instead of failing later when a worker tries
    # to start the engine.
    required=True,
)
parser.add_argument(
    "--max-workers",
    help="Max concurrent workers to analyse games with (limit this to your hardware thread count).",
    default=min(4, multiprocessing.cpu_count()),
    type=int,
)
args = parser.parse_args()

logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
"""Skip these many plies from the start (avoid training on opening)."""
|
||||
SKIP_PLIES: int = 20  # plies ignored at the start of every game

# Time limit, in seconds, for the engine to analyze each position.
TIME_LIMIT: float = 5

# Rows waiting to be written: (FEN, tensor string, evaluation, game result).
output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()
|
||||
|
||||
|
||||
async def load_games() -> AsyncIterator[pgn.Game]:
    """Yield games parsed one at a time from the PGN stream on stdin.

    NOTE: the body never awaits, so every step of this generator completes
    without suspending. The same generator object is handed to all workers,
    so keep it await-free unless the sharing scheme is changed as well.
    """
    while True:
        game = pgn.read_game(stdin)
        if not game:
            break
        yield game
|
||||
|
||||
|
||||
def _wdl_from_result(result: str) -> int | None:
    """Map a PGN 'Result' header to 1 (White won), 0 (draw) or -1 (Black won).

    Returns None for any other value (e.g. an unfinished game).
    """
    return {"1-0": 1, "0-1": -1, "1/2-1/2": 0}.get(result)


async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
    """
    Single worker that analyzes whole games.

    Code pattern taken from https://stackoverflow.com/a/54975674.

    Starts its own engine process from ``args.engine`` and pulls games from
    ``game_generator`` until it is exhausted. For every position after the
    first ``SKIP_PLIES`` plies that the engine reports as quiet, a row
    ``(FEN, tensor, evaluation, result)`` is put into the global
    ``output_queue``. Games whose 'Result' header is not a decisive result or
    a draw are logged and skipped.

    The engine is always asked to quit when the worker stops, including on
    error or cancellation, so no engine process is left behind.

    Raises:
        RuntimeError: if the engine returns no info string, or one that does
            not have the form ``NNUETrainInfo <quiet|non-quiet> <eval> <tensor>``.
    """
    _transport, engine = await chess.engine.popen_uci(args.engine)
    try:
        await engine.configure(dict(NNUETrainInfo="true"))

        async for game in game_generator:
            result_header = game.headers["Result"]
            wdl = _wdl_from_result(result_header)
            if wdl is None:
                logging.error("invalid 'Result' header: '%s'", result_header)
                continue

            board = game.board()

            logging.info(
                "Processing game %s, %s (%s) between %s as White and %s as Black.",
                game.headers["Event"],
                game.headers["Site"],
                game.headers["Date"],
                game.headers["White"],
                game.headers["Black"],
            )

            for ply, move in enumerate(game.mainline_moves()):
                board.push(move)
                if ply < SKIP_PLIES:
                    # Still inside the opening: play the move, analyze nothing.
                    continue

                # NOTE(review): the position after the final move is analyzed
                # too, even if the game is over there — confirm the engine
                # handles terminal positions.
                result = await engine.play(
                    board,
                    chess.engine.Limit(time=TIME_LIMIT),
                    info=chess.engine.INFO_ALL,
                    game=game,
                )

                info_str = result.info.get("string")
                if not info_str:
                    raise RuntimeError("Could not analyze position with engine.")

                # Validate the field count before unpacking, so a malformed
                # line is reported with its content rather than as a bare
                # unpacking ValueError.
                fields = info_str.split()
                if len(fields) != 4 or fields[0] != "NNUETrainInfo":
                    raise RuntimeError(f"Unexpected output from engine: {info_str}")
                _name, quiet, eval_abs, tensor = fields

                if quiet == "non-quiet":
                    logging.debug("discarded as non-quiet: '%s'", board.fen())
                    continue
                if quiet != "quiet":
                    raise RuntimeError(f"Unexpected output from engine: {info_str}")

                await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))
    finally:
        try:
            await engine.quit()
        except chess.engine.EngineError:
            # The engine is already gone or refused to quit; do not let this
            # mask the exception (if any) that brought us here.
            logging.debug("engine did not shut down cleanly", exc_info=True)
|
||||
|
||||
|
||||
async def analyse_games() -> None:
    """Task that manages reading PGNs and analyzing them.

    Creates one shared game generator and starts the worker tasks on it,
    returning once every worker has finished. The worker count is
    ``args.max_workers`` capped at the CPU count, and never less than one.
    """
    games_generator = load_games()

    # Clamp to at least one worker: with zero (or a negative) --max-workers
    # no task would be started and the run would finish without reading input.
    worker_count: int = max(1, min(args.max_workers, multiprocessing.cpu_count()))
    logging.info("Using %d concurrent worker tasks.", worker_count)

    async with TaskGroup() as tg:
        for _ in range(worker_count):
            tg.create_task(worker(games_generator))
|
||||
|
||||
|
||||
# Progress state shared by the writer task and the status logger.
completed = 0  # rows written to stdout so far
start_time = datetime.datetime.now()  # reference point for elapsed-time reports
|
||||
|
||||
|
||||
async def output_rows():
    """Writer task: copy rows from ``output_queue`` to stdout as TSV, forever.

    Each row is flushed as soon as it is written, marked done on the queue,
    and counted in the global ``completed``.
    """
    global completed

    tsv = csv.writer(stdout, delimiter="\t")
    while True:
        tsv.writerow(await output_queue.get())
        stdout.flush()
        output_queue.task_done()
        completed += 1
|
||||
|
||||
|
||||
async def status_logger():
    """Log the row count and elapsed time every five seconds, forever."""
    while True:
        await sleep(5)
        elapsed = datetime.datetime.now() - start_time
        logging.info(
            "Completed %d rows in %f seconds.", completed, elapsed.total_seconds()
        )
|
||||
|
||||
|
||||
async def main():
    """Run the analysis, writer and status tasks until all output is written.

    Waits for the analysis to finish, then waits for the writer to drain
    ``output_queue`` before cancelling the helper tasks, so no queued row can
    be dropped. If the writer task fails, its exception is re-raised instead
    of waiting forever on a queue nobody is emptying.
    """
    from asyncio import FIRST_COMPLETED, wait

    analyse_task = create_task(analyse_games())
    output_task = create_task(output_rows())
    status_task = create_task(status_logger())

    try:
        await analyse_task

        # Block until every queued row has been marked task_done() by the
        # writer — or until the writer itself stops, whichever comes first.
        drain_task = create_task(output_queue.join())
        await wait({drain_task, output_task}, return_when=FIRST_COMPLETED)
        if output_task.done():
            drain_task.cancel()
            # The writer loops forever, so "done" means it failed: surface it.
            output_task.result()
    finally:
        # Also reached when the analysis raises, so the helpers never linger.
        output_task.cancel()
        status_task.cancel()


if __name__ == "__main__":
    run(main())
|
Loading…
Reference in New Issue
Block a user