feat: better integrate nnue train pipeline

parent e18464eceb
commit 4bd2fd1d9a

nnue/README.md (new file, 20 lines)
@@ -0,0 +1,20 @@
# NNUE training tools

Python training pipeline for the evaluation neural network.
See the docstring in `src/nnue.rs` for information about the architecture of the NNUE.
The network is trained on both self-play games and its games on Lichess.
Both of these sources provide games in PGN format.

This folder includes the following scripts:
- `batch_pgn_data.py`: Combine and convert big PGN files into small chunked files.
- `process_pgn_data.py`: Convert PGN data into a format suitable for training.

Example training pipeline:
```bash
# Chunk all the PGN files in `games/`. Outputs by default to `batches/batch%d.pgn`.
./batch_pgn_data.py games/*.pgn

# Analyze batches 0 to 20 to turn them into training data. Outputs by default to `train_data/batch%d.tsv.gz`.
# Set --max-workers to the number of hardware threads / cores you have.
./process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch{0..20}.pgn
```
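For reference, the rows that `process_pgn_data.py` writes are gzipped tab-separated values, one position per line. Going by the script changes below, each row carries a FEN, the network's input tensor, the engine evaluation, and a win/draw/loss label; that column reading is an assumption drawn from this diff rather than from the trainer itself. A minimal sketch of iterating over the output:

```python
import csv
import gzip
from pathlib import Path

# Walk every training batch written to train_data/ and unpack its rows.
for path in sorted(Path("train_data").glob("batch*.tsv.gz")):
    with gzip.open(path, "rt") as f:
        for fen, tensor, eval_abs, wdl in csv.reader(f, delimiter="\t"):
            # fen: position (for reference), tensor: network input features,
            # eval_abs: engine evaluation, wdl: -1/0/1 game outcome (assumed meanings).
            print(fen, eval_abs, wdl)
```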

nnue/batch_pgn_data.py (4 changes, normal file → executable file)
@@ -43,8 +43,8 @@ def batch_games():
     output_folder: Path = args.output_folder
     output_folder.mkdir(exist_ok=True)
     for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)):
-        with (output_folder / f"batch{idx:04}.pgn").open("w") as f:
+        with (output_folder / f"batch{idx}.pgn").open("w") as f:
             for game in batch:
-                f.write(str(game) + "\n")
+                f.write(str(game) + "\n\n")
 
 batch_games()
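The doubled newline writes the blank separator line that the PGN export format expects between games, so each chunk can be read back game by game. A minimal sketch of re-reading a chunk with python-chess, in the same way `process_pgn_data.py` does, assuming the default `batches/batch0.pgn` output name:

```python
import chess.pgn

# Iterate over every game in one chunk produced by batch_pgn_data.py.
with open("batches/batch0.pgn") as f:
    while (game := chess.pgn.read_game(f)) is not None:
        print(game.headers.get("White", "?"), "vs", game.headers.get("Black", "?"))
```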

nnue/process_pgn_data.py (120 changes, normal file → executable file)
@@ -2,7 +2,6 @@
 
 """
 Processes PGN game data into a tsv format suitable for training.
-Inputs from stdin, outputs to stdout.
 
 Output columns:
 - FEN (for reference)
@@ -27,18 +26,26 @@ from asyncio import Queue, TaskGroup, create_task, run, sleep
 import logging
 import datetime
 import multiprocessing
+import gzip
 import csv
 
 import chess
 import chess.engine
 from typing import AsyncIterator, Literal
 from chess import pgn
 from sys import stdin, stdout
+from pathlib import Path
 
 parser = argparse.ArgumentParser(
     description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
 )
 
+parser.add_argument(
+    "--log",
+    choices=["DEBUG", "INFO", "WARNING", "ERROR"],
+    default="INFO",
+    help="Sets log level.",
+)
 parser.add_argument(
     "--engine",
     help="Set the file path of the chess_inator engine used to analyze the positions.",
@@ -50,26 +57,40 @@ parser.add_argument(
     default=min(4, multiprocessing.cpu_count()),
     type=int,
 )
+parser.add_argument(
+    "--preserve-partial",
+    action="store_true",
+    help="Keep output files that have not been fully written. These files may confuse this script when resuming operations.",
+)
+parser.add_argument("files", nargs="+", type=Path)
 args = parser.parse_args()
 
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=getattr(logging, str.upper(args.log)))
 
 
 """Skip these many plies from the start (avoid training on opening)."""
 SKIP_PLIES: int = 20
 
 """Time limit in seconds for each position to be analyzed."""
-TIME_LIMIT: float = 5
+TIME_LIMIT: float = 3
 
 
 output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()
 
 
-async def load_games():
+# stats for progress
+completed = 0
+discarded = 0
+current_outp: Path | None = None
+start_time = datetime.datetime.now()
+
+
+async def load_games(file: Path):
     """Load a PGN file and divide up the games for the workers to process."""
-    while game := pgn.read_game(stdin):
-        yield game
+    with open(file) as f:
+        while game := pgn.read_game(f):
+            yield game
 
 
 async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
@@ -101,7 +122,14 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
 
         skipped = 0
 
-        logging.info("Processing game %s, %s (%s) between %s as White and %s as Black.", game.headers["Event"], game.headers["Site"], game.headers["Date"], game.headers["White"], game.headers["Black"])
+        logging.info(
+            "Processing game %s, %s (%s) between %s as White and %s as Black.",
+            game.headers["Event"],
+            game.headers["Site"],
+            game.headers["Date"],
+            game.headers["White"],
+            game.headers["Black"],
+        )
 
         for move in game.mainline_moves():
             board.push(move)
@@ -123,6 +151,8 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
                 raise RuntimeError(f"Unexpected output from engine: {info_str}")
 
             if quiet == "non-quiet":
+                global discarded
+                discarded += 1
                 logging.debug("discarded as non-quiet: '%s'", board.fen())
                 continue
             elif quiet != "quiet":
@@ -131,9 +161,9 @@ async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
             await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))
 
 
-async def analyse_games():
+async def analyse_games(file: Path):
     """Task that manages reading PGNs and analyzing them."""
-    games_generator = load_games()
+    games_generator = load_games(file)
 
     async with TaskGroup() as tg:
         worker_count: int = min(args.max_workers, multiprocessing.cpu_count())
@@ -142,21 +172,17 @@ async def analyse_games():
             tg.create_task(worker(games_generator))
 
 
-completed = 0
-start_time = datetime.datetime.now()
-
-
-async def output_rows():
+async def output_rows(outp_file: Path):
     """TSV writer task."""
 
-    writer = csv.writer(stdout, delimiter="\t")
-    while True:
-        row = await output_queue.get()
-        writer.writerow(row)
-        stdout.flush()
-        output_queue.task_done()
-        global completed
-        completed += 1
+    with gzip.open(outp_file, "wt") as f:
+        writer = csv.writer(f, delimiter="\t")
+        while True:
+            row = await output_queue.get()
+            writer.writerow(row)
+            output_queue.task_done()
+            global completed
+            completed += 1
 
 
 async def status_logger():
@@ -164,20 +190,56 @@ async def status_logger():
     while True:
         await sleep(5)
         logging.info(
-            "Completed %d rows in %f seconds.",
+            "Completed %d rows in %f seconds. Discarded %d non-quiet positions.",
             completed,
             (datetime.datetime.now() - start_time).total_seconds(),
+            discarded,
         )
 
 
 async def main():
-    analyse_task = create_task(analyse_games())
-    output_task = create_task(output_rows())
     status_task = create_task(status_logger())
 
-    await analyse_task
-    output_task.cancel()
+    outp_dir = Path("train_data")
+    outp_dir.mkdir(exist_ok=True)
+
+    any_file = False
+    skipped = False
+
+    for file in args.files:
+        file: Path
+
+        outp_file = outp_dir / file.with_suffix(".tsv.gz").name
+
+        if outp_file.exists():
+            skipped = True
+            continue
+
+        any_file = True
+
+        if skipped:
+            logging.info("Resuming at file '%s'.", file)
+        else:
+            logging.info("Reading file '%s'.", file)
+
+        global current_outp
+        current_outp = outp_file
+
+        output_task = create_task(output_rows(outp_file))
+        analyse_task = create_task(analyse_games(file))
+        await analyse_task
+        output_task.cancel()
+
+    if not any_file:
+        logging.warning("Nothing to do. All input files have outputs already.")
+
     status_task.cancel()
 
 
-run(main())
+try:
+    run(main())
+except KeyboardInterrupt:
+    logging.critical("shutting down.")
+    if current_outp and not args.preserve_partial:
+        logging.critical("discarding partial output file %s", current_outp)
+        current_outp.unlink()
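With these changes the script resumes across runs: each input `batchN.pgn` maps to `train_data/batchN.tsv.gz`, inputs whose output already exists are skipped, and a run interrupted with Ctrl-C deletes its partially written output unless `--preserve-partial` is passed. A small sketch of previewing what a re-run would skip, mirroring the script's naming convention (paths assumed from the defaults above):

```python
from pathlib import Path

# Preview the resume behaviour: map each chunk to its output file, as
# process_pgn_data.py does, and report whether it would be skipped.
for pgn_file in sorted(Path("batches").glob("batch*.pgn")):
    outp_file = Path("train_data") / pgn_file.with_suffix(".tsv.gz").name
    state = "skip (output exists)" if outp_file.exists() else "process"
    print(f"{pgn_file} -> {outp_file}: {state}")
```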