feat: better integrate nnue train pipeline

dogeystamp 2024-12-30 15:39:56 -05:00
parent e18464eceb
commit 4bd2fd1d9a
3 changed files with 113 additions and 31 deletions

nnue/README.md (new file, 20 additions)

@@ -0,0 +1,20 @@
# NNUE training tools

Python training pipeline for the evaluation neural network.
See the docstring in `src/nnue.rs` for information about the architecture of the NNUE.

The network is trained both on self-play games and on the games it plays on Lichess.
Both of these sources provide games in PGN format.

This folder includes the following scripts:

- `batch_pgn_data.py`: Combine and convert big PGN files into small chunked files.
- `process_pgn_data.py`: Convert PGN data into a format suitable for training.

Example training pipeline:

```bash
# Chunk all the PGN files in `games/`. Outputs by default to `batches/batch%d.pgn`.
./batch_pgn_data.py games/*.pgn

# Analyze batches 0 to 20 to turn them into training data. Outputs by default to `train_data/batch%d.tsv.gz`.
# Set --max-workers to the number of hardware threads / cores you have.
./process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch{0..20}.pgn
```
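
Each row of the resulting `.tsv.gz` files holds the position FEN, the input tensor, the absolute evaluation, and a win/draw/loss label (`-1`, `0` or `1`), matching the tuple that `process_pgn_data.py` queues for output. A minimal sketch for inspecting a finished batch (the file name here is illustrative):

```python
"""Sketch: read one training batch written by process_pgn_data.py."""

import csv
import gzip

# Default output location for batches/batch0.pgn.
with gzip.open("train_data/batch0.tsv.gz", "rt") as f:
    for fen, tensor, eval_abs, wdl in csv.reader(f, delimiter="\t"):
        print(fen, int(eval_abs), int(wdl))
```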

nnue/batch_pgn_data.py (Normal file → Executable file, 4 changes)

@@ -43,8 +43,8 @@ def batch_games():
    output_folder: Path = args.output_folder
    output_folder.mkdir(exist_ok=True)

    for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)):
        with (output_folder / f"batch{idx:04}.pgn").open("w") as f:
        with (output_folder / f"batch{idx}.pgn").open("w") as f:
            for game in batch:
                f.write(str(game) + "\n")
                f.write(str(game) + "\n\n")


batch_games()
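
The switch to `"\n\n"` matters because PGN games are delimited by a blank line; games written back-to-back with a single newline may not round-trip cleanly through a PGN reader. A quick way to sanity-check a chunk, sketched with an illustrative path:

```python
"""Sketch: count games in one chunk to verify the blank-line separators."""

import chess.pgn

with open("batches/batch0.pgn") as f:
    count = 0
    while chess.pgn.read_game(f) is not None:
        count += 1
print(f"{count} games in batch0.pgn")
```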

nnue/process_pgn_data.py (Normal file → Executable file, 120 changes)

@@ -2,7 +2,6 @@
"""
Processes PGN game data into a tsv format suitable for training.
Inputs from stdin, outputs to stdout.
Output columns:
- FEN (for reference)
@@ -27,18 +26,26 @@ from asyncio import Queue, TaskGroup, create_task, run, sleep
import logging
import datetime
import multiprocessing
import gzip
import csv
import chess
import chess.engine
from typing import AsyncIterator, Literal
from chess import pgn
from sys import stdin, stdout
from pathlib import Path
parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)

parser.add_argument(
    "--log",
    choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    default="INFO",
    help="Sets log level.",
)

parser.add_argument(
    "--engine",
    help="Set the file path of the chess_inator engine used to analyze the positions.",
@@ -50,26 +57,40 @@ parser.add_argument(
    default=min(4, multiprocessing.cpu_count()),
    type=int,
)

parser.add_argument(
    "--preserve-partial",
    action="store_true",
    help="Keep output files that have not been fully written. These files may confuse this script when resuming operations.",
)
parser.add_argument("files", nargs="+", type=Path)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=getattr(logging, str.upper(args.log)))
"""Skip these many plies from the start (avoid training on opening)."""
SKIP_PLIES: int = 20
"""Time limit in seconds for each position to be analyzed."""
TIME_LIMIT: float = 5
TIME_LIMIT: float = 3
output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()
async def load_games():
# stats for progress
completed = 0
discarded = 0
current_outp: Path | None = None
start_time = datetime.datetime.now()
async def load_games(file: Path):
"""Load a PGN file and divide up the games for the workers to process."""
while game := pgn.read_game(stdin):
yield game
with open(file) as f:
while game := pgn.read_game(f):
yield game
async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
@@ -101,7 +122,14 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
        skipped = 0

        logging.info("Processing game %s, %s (%s) between %s as White and %s as Black.", game.headers["Event"], game.headers["Site"], game.headers["Date"], game.headers["White"], game.headers["Black"])
        logging.info(
            "Processing game %s, %s (%s) between %s as White and %s as Black.",
            game.headers["Event"],
            game.headers["Site"],
            game.headers["Date"],
            game.headers["White"],
            game.headers["Black"],
        )

        for move in game.mainline_moves():
            board.push(move)
@@ -123,6 +151,8 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
                raise RuntimeError(f"Unexpected output from engine: {info_str}")

            if quiet == "non-quiet":
                global discarded
                discarded += 1
                logging.debug("discarded as non-quiet: '%s'", board.fen())
                continue
            elif quiet != "quiet":
@@ -131,9 +161,9 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
            await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))
async def analyse_games():
async def analyse_games(file: Path):
"""Task that manages reading PGNs and analyzing them."""
games_generator = load_games()
games_generator = load_games(file)
async with TaskGroup() as tg:
worker_count: int = min(args.max_workers, multiprocessing.cpu_count())
@@ -142,21 +172,17 @@ async def analyse_games():
            tg.create_task(worker(games_generator))
completed = 0
start_time = datetime.datetime.now()
async def output_rows():
async def output_rows(outp_file: Path):
"""TSV writer task."""
writer = csv.writer(stdout, delimiter="\t")
while True:
row = await output_queue.get()
writer.writerow(row)
stdout.flush()
output_queue.task_done()
global completed
completed += 1
with gzip.open(outp_file, "wt") as f:
writer = csv.writer(f, delimiter="\t")
while True:
row = await output_queue.get()
writer.writerow(row)
output_queue.task_done()
global completed
completed += 1
async def status_logger():
@@ -164,20 +190,56 @@ async def status_logger():
    while True:
        await sleep(5)
        logging.info(
            "Completed %d rows in %f seconds.",
            "Completed %d rows in %f seconds. Discarded %d non-quiet positions.",
            completed,
            (datetime.datetime.now() - start_time).total_seconds(),
            discarded,
        )
async def main():
    analyse_task = create_task(analyse_games())
    output_task = create_task(output_rows())
    status_task = create_task(status_logger())

    await analyse_task
    output_task.cancel()

    outp_dir = Path("train_data")
    outp_dir.mkdir(exist_ok=True)

    any_file = False
    skipped = False
    for file in args.files:
        file: Path
        outp_file = outp_dir / file.with_suffix(".tsv.gz").name
        if outp_file.exists():
            skipped = True
            continue
        any_file = True

        if skipped:
            logging.info("Resuming at file '%s'.", file)
            skipped = False
        else:
            logging.info("Reading file '%s'.", file)

        global current_outp
        current_outp = outp_file

        output_task = create_task(output_rows(outp_file))
        analyse_task = create_task(analyse_games(file))
        await analyse_task
        output_task.cancel()

    if not any_file:
        logging.warning("Nothing to do. All input files have outputs already.")

    status_task.cancel()
run(main())
try:
    run(main())
except KeyboardInterrupt:
    logging.critical("shutting down.")
    if current_outp and not args.preserve_partial:
        logging.critical("discarding partial output file %s", current_outp)
        current_outp.unlink()
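
An interrupt mid-batch leaves a partially written `.tsv.gz`, and because the resume logic skips any input whose output file already exists, a truncated file would silently shrink the training set; hence the handler unlinks it unless `--preserve-partial` is given. If partial files are kept, they can be detected afterwards: a gzip stream cut off mid-write raises `EOFError` when read to the end. A hedged sketch (the `is_complete` helper is hypothetical, not part of the script):

```python
"""Sketch: find truncated batch outputs kept via --preserve-partial."""

import gzip
from pathlib import Path

def is_complete(path: Path) -> bool:
    """Return False if the gzip stream is missing its end-of-stream marker."""
    try:
        with gzip.open(path, "rb") as f:
            while f.read(1 << 16):
                pass
        return True
    except EOFError:
        return False

for p in sorted(Path("train_data").glob("*.tsv.gz")):
    print(p, "ok" if is_complete(p) else "TRUNCATED")
```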