Compare commits

...

10 Commits

Author SHA1 Message Date
0ae4c889dc
tune: NNUE "quiet" threshold
the results aren't nearly as bonkers as before the bugfix in a47778a,
so tone down the threshold
2024-12-30 17:21:20 -05:00
a47778ae6c
fix: board not same after running minmax
caused by an early return from hard stop that bypasses unmake move
2024-12-30 17:08:27 -05:00
4bd2fd1d9a
feat: better integrate nnue train pipeline 2024-12-30 15:39:56 -05:00
e18464eceb
tune: search time increased
last version (44e4cab) suffers less from time pressure, but blunders,
possibly due to not thinking long enough
2024-12-30 13:59:47 -05:00
44e4cabdc1
tune: search time reduced
this should make it get less time pressure
2024-12-30 12:22:09 -05:00
cb65671444
feat: nnue training data pipeline tools 2024-12-30 12:21:07 -05:00
3885bd7948
feat: fast-chess-tag can save multiple PGNs for the same players
it will use number indices instead of asking to delete the first PGN.
2024-12-30 09:51:04 -05:00
612e6ffc15
perf: repetition draw avoidance now only looks at necessary moves
uses half-move counter.
2024-12-29 22:32:00 -05:00
9995f13693
fix: transposition table no longer instantly returns a result
this has issues like cutting off the PV line, and more importantly,
bypassing draw by repetition detection.
2024-12-29 18:44:38 -05:00
f6bf1b46c7
feat: nnue info now shows absolute eval 2024-12-29 18:04:39 -05:00
8 changed files with 446 additions and 45 deletions

2
.gitignore vendored
View File

@ -1,2 +1,4 @@
/target
TODO.txt
nnue/batches
nnue/venv

View File

@ -59,9 +59,13 @@ mkdir -p games
PGN=games/"$TAG1"__"$TAG2".pgn
rm -f engine1 engine2
if [ -f "$PGN" ]; then
rm -i "$PGN"
fi
IDX=1
while [ -e "$PGN" ]; do
PGN=games/"$TAG1"__"$TAG2"__"$IDX".pgn
IDX=$(( $IDX + 1 ))
done
printf "using pgn output: %s\n" "$PGN" > /dev/stderr
git checkout "$TAG1"
cargo build --release

20
nnue/README.md Normal file
View File

@ -0,0 +1,20 @@
# NNUE training tools
Python training pipeline for the evaluation neural network.
See the docstring in `src/nnue.rs` for information about the architecture of the NNUE.
The network is trained on both self-play games, and its games on Lichess.
Both of these sources provide games in PGN format.
This folder includes the following scripts:
- `batch_pgn_data.py`: Combine and convert big PGN files into small chunked files.
- `process_pgn_data.py`: Convert PGN data into a format suitable for training.
Example training pipeline:
```bash
# chunk all the PGN files in `games/`. outputs by default to `batches/batch%d.pgn`.
./batch_pgn_data.py games/*.pgn
# analyze batches 0 to 20 to turn them into training data. outputs by default to train_data/batch%d.tsv.gz.
# set max-workers to the number of hardware threads / cores you have.
./process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch{0..20}.pgn
```

50
nnue/batch_pgn_data.py Executable file
View File

@ -0,0 +1,50 @@
#!/usr/bin/env python
"""
Batch PGN data into files, since the training data pipeline can't resume processing within a single file.
"""
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "chess",
# ]
# ///
from typing import Iterator
import chess.pgn
import argparse
import itertools
from pathlib import Path
"""Games to include per file in output."""
parser = argparse.ArgumentParser()
parser.add_argument("files", nargs="+", type=Path)
parser.add_argument("--batch-size", type=int, help="Number of games to save in each output file.", default=8)
parser.add_argument("--output-folder", type=Path, help="Folder to save batched games in.", default=Path("batches"))
args = parser.parse_args()
def generate_games_in_file(path: Path) -> Iterator[chess.pgn.Game]:
"""Read games from a single PGN file."""
with open(path) as f:
while game := chess.pgn.read_game(f):
game.headers["PGNPath"] = str(path)
yield game
def generate_games() -> Iterator[chess.pgn.Game]:
"""Read games from all files."""
for path in args.files:
yield from generate_games_in_file(path)
def batch_games():
"""Write games in batches."""
output_folder: Path = args.output_folder
output_folder.mkdir(exist_ok=True)
for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)):
with (output_folder / f"batch{idx}.pgn").open("w") as f:
for game in batch:
f.write(str(game) + "\n\n")
batch_games()

245
nnue/process_pgn_data.py Executable file
View File

@ -0,0 +1,245 @@
#!/usr/bin/env python
"""
Processes PGN game data into a tsv format suitable for training.
Output columns:
- FEN (for reference)
- ALL 768-bit binary string representing the position
- Evaluation (centipawns) from white perspective
- Result of the game (-1, 0, 1)
This script depends on the `chess` package.
Install it, or run this script using `pipx run process_pgn_data.py`.
The script also depends on the chess_inator engine for analysis and filtering.
"""
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "chess",
# ]
# ///
import argparse
from asyncio import Queue, TaskGroup, create_task, run, sleep
import logging
import datetime
import multiprocessing
import gzip
import csv
import chess
import chess.engine
from typing import AsyncIterator, Literal
from chess import pgn
from pathlib import Path
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
"--log",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
help="Sets log level.",
)
parser.add_argument(
"--engine",
help="Set the file path of the chess_inator engine used to analyze the positions.",
type=Path,
)
parser.add_argument(
"--max-workers",
help="Max concurrent workers to analyse games with (limit this to your hardware thread count).",
default=min(4, multiprocessing.cpu_count()),
type=int,
)
parser.add_argument(
"--preserve-partial",
action="store_true",
help="Keep output files that have not been fully written. These files may confuse this script when resuming operations.",
)
parser.add_argument("files", nargs="+", type=Path)
args = parser.parse_args()
logging.basicConfig(level=getattr(logging, str.upper(args.log)))
"""Skip these many plies from the start (avoid training on opening)."""
SKIP_PLIES: int = 20
"""Time limit in seconds for each position to be analyzed."""
TIME_LIMIT: float = 3
output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()
# stats for progress
completed = 0
discarded = 0
current_outp: Path | None = None
start_time = datetime.datetime.now()
async def load_games(file: Path):
"""Load a PGN file and divide up the games for the workers to process."""
with open(file) as f:
while game := pgn.read_game(f):
yield game
async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
"""
Single worker that analyzes whole games.
Code pattern taken from https://stackoverflow.com/a/54975674.
Puts rows of output into a global queue.
"""
transport, engine = await chess.engine.popen_uci(args.engine)
await engine.configure(dict(NNUETrainInfo="true"))
async for game in game_generator:
wdl: int | None = None
match game.headers["Result"]:
case "1-0":
wdl = 1
case "0-1":
wdl = -1
case "1/2-1/2":
wdl = 0
case other_result:
logging.error("invalid 'Result' header: '%s'", other_result)
continue
board = game.board()
skipped = 0
logging.info(
"Processing game %s, %s (%s) between %s as White and %s as Black.",
game.headers["Event"],
game.headers["Site"],
game.headers["Date"],
game.headers["White"],
game.headers["Black"],
)
for move in game.mainline_moves():
board.push(move)
if skipped < SKIP_PLIES:
skipped += 1
continue
result = await engine.play(
board,
chess.engine.Limit(time=TIME_LIMIT),
info=chess.engine.INFO_ALL,
game=game,
)
info_str = result.info.get("string")
if not info_str:
raise RuntimeError("Could not analyze position with engine.")
(name, quiet, eval_abs, tensor) = info_str.split()
if not name == "NNUETrainInfo":
raise RuntimeError(f"Unexpected output from engine: {info_str}")
if quiet == "non-quiet":
global discarded
discarded += 1
logging.debug("discarded as non-quiet: '%s'", board.fen())
continue
elif quiet != "quiet":
raise RuntimeError(f"Unexpected output from engine: {info_str}")
await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))
async def analyse_games(file: Path):
"""Task that manages reading PGNs and analyzing them."""
games_generator = load_games(file)
async with TaskGroup() as tg:
worker_count: int = min(args.max_workers, multiprocessing.cpu_count())
logging.info("Using %d concurrent worker tasks.", worker_count)
for i in range(worker_count):
tg.create_task(worker(games_generator))
async def output_rows(outp_file: Path):
"""TSV writer task."""
with gzip.open(outp_file, "wt") as f:
writer = csv.writer(f, delimiter="\t")
while True:
row = await output_queue.get()
writer.writerow(row)
output_queue.task_done()
global completed
completed += 1
async def status_logger():
"""Periodically print status."""
while True:
await sleep(5)
logging.info(
"Completed %d rows in %f seconds. Discarded %d non-quiet positions.",
completed,
(datetime.datetime.now() - start_time).total_seconds(),
discarded,
)
async def main():
status_task = create_task(status_logger())
outp_dir = Path("train_data")
outp_dir.mkdir(exist_ok=True)
any_file = False
skipped = False
for file in args.files:
file: Path
outp_file = outp_dir / file.with_suffix(".tsv.gz").name
if outp_file.exists():
skipped = True
continue
any_file = True
if skipped:
logging.info("Resuming at file '%s'.", file)
else:
logging.info("Reading file '%s'.", file)
global current_outp
current_outp = outp_file
output_task = create_task(output_rows(outp_file))
analyse_task = create_task(analyse_games(file))
await analyse_task
output_task.cancel()
if not any_file:
logging.warning("Nothing to do. All input files have outputs already.")
status_task.cancel()
try:
run(main())
except KeyboardInterrupt:
logging.critical("shutting down.")
if current_outp and not args.preserve_partial:
logging.critical("discarding partial output file %s", current_outp)
current_outp.unlink()

View File

@ -19,10 +19,10 @@ use std::str::FromStr;
pub mod coordination;
pub mod eval;
pub mod nnue;
pub mod fen;
mod hash;
pub mod movegen;
pub mod nnue;
pub mod random;
pub mod search;
@ -554,7 +554,7 @@ const HISTORY_SIZE: usize = 100;
impl BoardHistory {
/// Counts occurences of this hash in the history.
fn count(&self, hash: Zobrist) -> usize {
fn _count(&self, hash: Zobrist) -> usize {
let mut ans = 0;
let mut i = self.ptr_start;
@ -568,6 +568,20 @@ impl BoardHistory {
ans
}
/// Find if there are at least `n` matches for a hash in the last `recent` plies.
fn at_least_in_recent(&self, mut n: usize, recent: usize, hash: Zobrist) -> bool {
let mut i = self.ptr_end - recent;
while i != self.ptr_end && n > 0 {
if self.hashes[usize::from(i)] == hash {
n -= 1;
}
i += 1;
}
n == 0
}
/// Add (push) hash to history.
fn push(&mut self, hash: Zobrist) {
self.hashes[usize::from(self.ptr_end)] = hash;
@ -633,6 +647,12 @@ impl Board {
self.history.push(self.zobrist);
}
/// Is this position a draw by three repetitions?
pub fn is_repetition(&mut self) -> bool {
self.history
.at_least_in_recent(2, self.half_moves, self.zobrist)
}
/// Get iterator over all squares.
pub fn squares() -> impl Iterator<Item = Square> {
(0..N_SQUARES).map(Square::try_from).map(|x| x.unwrap())
@ -948,20 +968,30 @@ mod tests {
history.push(board.zobrist);
}
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 1);
assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 1);
assert!(history.at_least_in_recent(1, 1, board.zobrist));
assert!(history.at_least_in_recent(2, 3, board.zobrist));
assert!(history.at_least_in_recent(1, 3, board.zobrist));
let board_empty = Board::default();
history.push(board_empty.zobrist);
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 2);
assert_eq!(history.count(board_empty.zobrist), 1);
assert!(!history.at_least_in_recent(1, 1, board.zobrist));
assert!(history.at_least_in_recent(1, 2, board.zobrist));
assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 2);
assert_eq!(history._count(board_empty.zobrist), 1);
assert!(history.at_least_in_recent(1, 3, board.zobrist));
assert!(history.at_least_in_recent(1, 20, board_empty.zobrist));
assert!(history.at_least_in_recent(1, 15, board_empty.zobrist));
assert!(history.at_least_in_recent(1, 1, board_empty.zobrist));
for _ in 0..3 {
history.push(board_empty.zobrist);
}
assert_eq!(history.count(board_empty.zobrist), 4);
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 5);
assert_eq!(history._count(board_empty.zobrist), 4);
assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 5);
}
#[test]
@ -982,7 +1012,8 @@ mod tests {
}
// this is the third occurence, but beforehand there are two occurences
assert_eq!(board.history.count(board.zobrist), 2);
assert_eq!(board.history._count(board.zobrist), 2);
assert!(board.is_repetition(), "fen: {}", board.to_fen());
}
/// engine should take advantage of the three time repetition rule
@ -1032,7 +1063,7 @@ mod tests {
expected_bestmv.make(&mut board);
eprintln!(
"after expected mv, board repeated {} times",
board.history.count(board.zobrist)
board.history._count(board.zobrist)
);
assert_eq!(

View File

@ -370,19 +370,18 @@ fn task_engine(tx_main: Sender<MsgToMain>, rx_engine: Receiver<MsgToEngine>) {
let mut info: Vec<String> = Vec::new();
if state.config.nnue_train_info {
let is_quiet = chess_inator::search::is_quiescent_position(&board, eval);
let is_quiet = if is_quiet {"quiet"} else {"non-quiet"};
let is_quiet = if is_quiet { "quiet" } else { "non-quiet" };
let board_tensor = chess_inator::nnue::InputTensor::from_board(&board);
info.push(format!("NNUETrainInfo {} {}", is_quiet, {board_tensor}))
let abs_eval = EvalInt::from(eval) * EvalInt::from(board.get_turn().sign());
info.push(format!("NNUETrainInfo {} {} {}", is_quiet, abs_eval, {
board_tensor
}))
}
tx_main
.send(MsgToMain::Bestmove(MsgBestmove {
pv,
eval,
info,
}))
.send(MsgToMain::Bestmove(MsgBestmove { pv, eval, info }))
.unwrap();
}
MsgToEngine::Stop => {}

View File

@ -224,7 +224,7 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
}
}
let is_repetition_draw = board.history.count(board.zobrist) >= 2;
let is_repetition_draw = board.is_repetition();
let phase_factor = EvalInt::from(board.eval.min_maj_pieces / 5);
// positive here since we're looking from the opposite perspective.
// if white caused a draw, then we'd be black here.
@ -262,37 +262,60 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
}
}
// default to worst, then gradually improve
let mut alpha = mm.alpha.unwrap_or(EVAL_WORST);
// our best is their worst
let beta = mm.beta.unwrap_or(EVAL_BEST);
let mvs = if mm.quiesce {
board.gen_captures().into_iter().collect::<Vec<_>>()
enum MoveGenerator {
/// Use heavily pruned search to generate moves leading to a quiet position.
Quiescence,
/// Generate all legal moves.
Normal,
/// Only evaluate a single move.
None,
}
let mut move_generator = if mm.quiesce {
MoveGenerator::Quiescence
} else {
board.gen_moves().into_iter().collect::<Vec<_>>()
MoveGenerator::Normal
};
let mut trans_table_move: Option<Move> = None;
// get transposition table entry
if state.config.enable_trans_table {
if let Some(entry) = &state.cache[board.zobrist] {
trans_table_move = Some(entry.best_move);
if entry.is_qsearch == mm.quiesce && entry.depth >= mm.depth {
if let SearchEval::Exact(_) | SearchEval::Upper(_) = entry.eval {
// at this point, we could just return the best move + eval given, but this
// bypasses the draw by repetition checks in `minmax`. so just don't generate
// any other moves than the best move.
move_generator = MoveGenerator::None;
}
}
}
}
let mvs = match move_generator {
MoveGenerator::Quiescence => board.gen_captures().into_iter().collect::<Vec<_>>(),
MoveGenerator::Normal => board.gen_moves().into_iter().collect::<Vec<_>>(),
MoveGenerator::None => Vec::new(),
};
let mut mvs: Vec<_> = mvs
.into_iter()
.map(|mv| (move_priority(board, &mv, state), mv))
.collect();
// get transposition table entry
if state.config.enable_trans_table {
if let Some(entry) = &state.cache[board.zobrist] {
if entry.is_qsearch == mm.quiesce && entry.depth >= mm.depth {
if let SearchEval::Exact(_) | SearchEval::Upper(_) = entry.eval {
// no point looking for a better move
return (vec![entry.best_move], entry.eval);
}
}
mvs.push((EVAL_BEST, entry.best_move));
}
if let Some(trans_table_move) = trans_table_move {
mvs.push((EVAL_BEST, trans_table_move))
}
// sort moves by decreasing priority
mvs.sort_unstable_by_key(|mv| -mv.0);
// default to worst, then gradually improve
let mut alpha = mm.alpha.unwrap_or(EVAL_WORST);
// our best is their worst
let beta = mm.beta.unwrap_or(EVAL_BEST);
let mut abs_best = SearchEval::Exact(EVAL_WORST);
if mm.quiesce {
@ -340,6 +363,7 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
quiesce: mm.quiesce,
},
);
anti_mv.unmake(board);
// propagate hard stops
if matches!(score, SearchEval::Stopped) {
@ -353,7 +377,6 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
best_continuation = continuation;
}
alpha = max(alpha, abs_best.into());
anti_mv.unmake(board);
if alpha >= beta && state.config.alpha_beta_on {
// alpha-beta prune.
//
@ -457,7 +480,7 @@ impl TimeLimits {
// hard timeout (max)
let mut hard_ms = 100_000;
// soft timeout (default max)
let mut soft_ms = 1_200;
let mut soft_ms = 1_500;
// in some situations we can think longer
if eval.phase <= 13 {
@ -465,11 +488,11 @@ impl TimeLimits {
// opening
soft_ms = if ourtime_ms > 300_000 {
4_500
3_000
} else if ourtime_ms > 600_000 {
8_000
5_000
} else if ourtime_ms > 1_200_000 {
12_000
8_000
} else {
soft_ms
}
@ -558,7 +581,7 @@ pub fn best_move(board: &mut Board, engine_state: &mut EngineState) -> Option<Mo
/// It is the caller's responsibility to get the search evaluation and pass it to this function.
pub fn is_quiescent_position(board: &Board, eval: SearchEval) -> bool {
// max centipawn value difference to call "similar"
const THRESHOLD: EvalInt = 170;
const THRESHOLD: EvalInt = 120;
if board.is_check(board.turn) {
return false;
@ -573,3 +596,30 @@ pub fn is_quiescent_position(board: &Board, eval: SearchEval) -> bool {
(board.eval() - EvalInt::from(abs_eval)).abs() <= THRESHOLD.abs()
}
#[cfg(test)]
mod tests {
use super::*;
/// Test that running minmax does not alter the board.
#[test]
fn test_board_same() {
let (_tx, rx) = mpsc::channel();
let cache = TranspositionTable::new(1);
let mut engine_state = EngineState::new(
SearchConfig {
depth: 3,
..Default::default()
},
rx,
cache,
TimeLimits::from_movetime(20),
);
let mut board =
Board::from_fen("2rq1rk1/pp1bbppp/3p4/4p1B1/2B1P1n1/1PN5/P1PQ1PPP/R3K2R w KQ - 1 14")
.unwrap();
let orig_board = board;
let (_line, _eval) = best_line(&mut board, &mut engine_state);
assert_eq!(board, orig_board)
}
}