246 lines
6.6 KiB
Python
Executable File
246 lines
6.6 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
"""
|
|
Processes PGN game data into a tsv format suitable for training.
|
|
|
|
Output columns:
|
|
- FEN (for reference)
|
|
- ALL 768-bit binary string representing the position
|
|
- Evaluation (centipawns) from white perspective
|
|
- Result of the game (-1, 0, 1)
|
|
|
|
This script depends on the `chess` package.
|
|
Install it, or run this script using `pipx run process_pgn_data.py`.
|
|
The script also depends on the chess_inator engine for analysis and filtering.
|
|
"""
|
|
|
|
# /// script
|
|
# requires-python = ">=3.11"
|
|
# dependencies = [
|
|
# "chess",
|
|
# ]
|
|
# ///
|
|
|
|
import argparse
|
|
from asyncio import Queue, TaskGroup, create_task, run, sleep
|
|
import logging
|
|
import datetime
|
|
import multiprocessing
|
|
import gzip
|
|
import csv
|
|
|
|
import chess
|
|
import chess.engine
|
|
from typing import AsyncIterator, Literal
|
|
from chess import pgn
|
|
from pathlib import Path
|
|
|
|
parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)

parser.add_argument(
    "--log",
    choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    default="INFO",
    help="Sets log level.",
)

parser.add_argument(
    "--engine",
    help="Set the file path of the chess_inator engine used to analyze the positions.",
    type=Path,
    # Every worker launches this executable; without it nothing can be
    # analysed, so fail at argument parsing instead of inside each worker.
    required=True,
)
parser.add_argument(
    "--max-workers",
    help="Max concurrent workers to analyse games with (limit this to your hardware thread count).",
    default=min(4, multiprocessing.cpu_count()),
    type=int,
)
parser.add_argument(
    "--preserve-partial",
    action="store_true",
    help="Keep output files that have not been fully written. These files may confuse this script when resuming operations.",
)
parser.add_argument("files", nargs="+", type=Path)

args = parser.parse_args()

# With zero workers no game would be analysed, yet an (empty) output file
# would still be created and then skipped as "done" on the next run.
if args.max_workers < 1:
    parser.error("--max-workers must be at least 1")

# The --log choices are already valid logging level names.
logging.basicConfig(level=args.log)


# Number of plies skipped at the start of every game, to avoid training on
# opening positions.
SKIP_PLIES: int = 20

# Time limit, in seconds, the engine gets for each analysed position.
TIME_LIMIT: float = 3


# Rows produced by the workers and consumed by the TSV writer:
# (FEN, position feature string, evaluation, game result).
output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()


# Progress statistics, reported periodically by the status logger.
completed: int = 0  # rows written so far (across all files)
discarded: int = 0  # positions dropped as non-quiet
current_outp: Path | None = None  # output file currently being written
start_time: datetime.datetime = datetime.datetime.now()
|
|
|
|
|
|
async def load_games(file: Path):
    """Load a PGN file and divide up the games for the workers to process.

    Yields one ``chess.pgn.Game`` at a time until ``pgn.read_game`` returns
    ``None`` at end of file. The generator body never awaits, so several
    workers can share it. Each ``__anext__`` call runs one ``read_game`` to
    completion without suspending, and every game goes to exactly one worker.
    Note that ``read_game`` parses synchronously, blocking the event loop
    while it runs.

    The file is decoded as UTF-8 instead of the platform's locale encoding
    (e.g. cp1252 on Windows), which would mis-decode or reject UTF-8 PGNs.
    Undecodable bytes are replaced rather than aborting the whole file.
    PGN move text is ASCII, so such bytes should normally only affect tag
    values or comments, which this script only logs.
    """
    with open(file, encoding="utf-8", errors="replace") as f:
        while game := pgn.read_game(f):
            yield game
|
|
|
|
|
|
async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
|
|
"""
|
|
Single worker that analyzes whole games.
|
|
|
|
Code pattern taken from https://stackoverflow.com/a/54975674.
|
|
|
|
Puts rows of output into a global queue.
|
|
"""
|
|
transport, engine = await chess.engine.popen_uci(args.engine)
|
|
await engine.configure(dict(NNUETrainInfo="true"))
|
|
|
|
async for game in game_generator:
|
|
wdl: int | None = None
|
|
|
|
match game.headers["Result"]:
|
|
case "1-0":
|
|
wdl = 1
|
|
case "0-1":
|
|
wdl = -1
|
|
case "1/2-1/2":
|
|
wdl = 0
|
|
case other_result:
|
|
logging.error("invalid 'Result' header: '%s'", other_result)
|
|
continue
|
|
|
|
board = game.board()
|
|
|
|
skipped = 0
|
|
|
|
logging.info(
|
|
"Processing game %s, %s (%s) between %s as White and %s as Black.",
|
|
game.headers["Event"],
|
|
game.headers["Site"],
|
|
game.headers["Date"],
|
|
game.headers["White"],
|
|
game.headers["Black"],
|
|
)
|
|
|
|
for move in game.mainline_moves():
|
|
board.push(move)
|
|
if skipped < SKIP_PLIES:
|
|
skipped += 1
|
|
continue
|
|
result = await engine.play(
|
|
board,
|
|
chess.engine.Limit(time=TIME_LIMIT),
|
|
info=chess.engine.INFO_ALL,
|
|
game=game,
|
|
)
|
|
|
|
info_str = result.info.get("string")
|
|
if not info_str:
|
|
raise RuntimeError("Could not analyze position with engine.")
|
|
(name, quiet, eval_abs, tensor) = info_str.split()
|
|
if not name == "NNUETrainInfo":
|
|
raise RuntimeError(f"Unexpected output from engine: {info_str}")
|
|
|
|
if quiet == "non-quiet":
|
|
global discarded
|
|
discarded += 1
|
|
logging.debug("discarded as non-quiet: '%s'", board.fen())
|
|
continue
|
|
elif quiet != "quiet":
|
|
raise RuntimeError(f"Unexpected output from engine: {info_str}")
|
|
|
|
await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))
|
|
|
|
|
|
async def analyse_games(file: Path):
    """Task that manages reading PGNs and analyzing them.

    Starts up to ``args.max_workers`` (capped at the CPU count, and at
    least one) ``worker`` tasks. They all share a single ``load_games``
    generator, so each game in *file* is analysed exactly once. Returns when
    every worker has finished. If a worker raises, the TaskGroup cancels the
    remaining workers and raises an ExceptionGroup.
    """
    from contextlib import aclosing

    # With zero workers nothing would be analysed, yet the caller would
    # still produce an (empty) output file that looks finished on resume.
    worker_count: int = max(1, min(args.max_workers, multiprocessing.cpu_count()))
    logging.info("Using %d concurrent worker tasks.", worker_count)

    # aclosing() guarantees the generator, and with it the open PGN file,
    # is finalised even when the workers stop early because of an error.
    async with aclosing(load_games(file)) as games_generator:
        async with TaskGroup() as tg:
            for _ in range(worker_count):
                tg.create_task(worker(games_generator))
|
|
|
|
|
|
async def output_rows(outp_file: Path):
    """TSV writer task.

    Takes rows from the global ``output_queue`` and writes them as
    tab-separated lines into the gzip-compressed *outp_file*. Each written
    row is marked done on the queue and counted in the global ``completed``.
    Runs until cancelled; the ``with`` block closes the file on
    cancellation.
    """
    global completed

    # The csv module requires files it writes to be opened with newline="".
    # Otherwise its default "\r\n" row terminator gets newline-translated
    # into "\r\r\n" on Windows.
    with gzip.open(outp_file, "wt", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        while True:
            row = await output_queue.get()
            writer.writerow(row)
            output_queue.task_done()
            completed += 1
|
|
|
|
|
|
async def status_logger():
    """Report progress at a fixed interval, forever (until cancelled).

    Each report contains the rows written so far, the wall-clock time since
    ``start_time``, and the number of positions dropped as non-quiet.
    """
    interval_seconds = 5
    while True:
        await sleep(interval_seconds)
        elapsed = datetime.datetime.now() - start_time
        logging.info(
            "Completed %d rows in %f seconds. Discarded %d non-quiet positions.",
            completed,
            elapsed.total_seconds(),
            discarded,
        )
|
|
|
|
|
|
async def main():
    """Convert every input PGN file into a gzipped TSV file under ``train_data/``.

    The output name is the input's file name with its last suffix replaced
    by ``.tsv.gz``. Inputs whose output already exists are skipped, which
    lets an interrupted run be resumed. While a file is being written,
    ``current_outp`` points at it so an interrupted run can discard the
    partial file.
    """
    global current_outp

    status_task = create_task(status_logger())

    outp_dir = Path("train_data")
    outp_dir.mkdir(exist_ok=True)

    any_file = False
    skipped = False

    for file in args.files:
        outp_file = outp_dir / file.with_suffix(".tsv.gz").name

        # An existing output means an earlier run already finished this file.
        if outp_file.exists():
            skipped = True
            continue

        any_file = True

        if skipped:
            logging.info("Resuming at file '%s'.", file)
            # Only the first file after a run of skipped ones is a resume point.
            skipped = False
        else:
            logging.info("Reading file '%s'.", file)

        current_outp = outp_file

        output_task = create_task(output_rows(outp_file))
        await analyse_games(file)
        await _stop_writer(output_task)

        # The file is complete; an interrupt from here on must not delete it.
        current_outp = None

    if not any_file:
        logging.warning("Nothing to do. All input files have outputs already.")

    status_task.cancel()


async def _stop_writer(output_task) -> None:
    """Wait until every queued row is written, then stop the writer and let it close its file."""
    from asyncio import FIRST_COMPLETED, CancelledError, current_task, wait

    # Wait for the queue to drain. Also stop waiting if the writer dies,
    # so its error is re-raised below instead of join() hanging forever.
    drained = create_task(output_queue.join())
    await wait({drained, output_task}, return_when=FIRST_COMPLETED)
    drained.cancel()

    output_task.cancel()
    try:
        # Awaiting lets the writer's ``with`` block close the gzip file before
        # the next file starts, and re-raises the writer's error if it failed.
        await output_task
    except CancelledError:
        # Expected, since we cancelled the writer ourselves. Propagate only
        # if this task is itself being cancelled (e.g. Ctrl-C).
        if current_task().cancelling():
            raise
|
|
|
|
|
|
def _discard_partial_output() -> None:
    """Delete the output file being written when the run stopped, unless --preserve-partial was given.

    A leftover partial file would be mistaken for finished output and
    skipped on the next run.
    """
    if current_outp and not args.preserve_partial:
        logging.critical("discarding partial output file %s", current_outp)
        # The writer task may not have created the file yet.
        current_outp.unlink(missing_ok=True)


try:
    run(main())
except KeyboardInterrupt:
    logging.critical("shutting down.")
    _discard_partial_output()
except BaseException:
    # A crash also leaves a half-written file behind; clean it up the same
    # way as an interrupt, but still report the error.
    _discard_partial_output()
    raise
|