Compare commits
No commits in common. "0ae4c889dc57ee83a76ac59fc1493d77fe4c1c38" and "87501b5c94bb22add7e054531887a1457b117412" have entirely different histories.
0ae4c889dc
...
87501b5c94
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,4 +1,2 @@
|
|||||||
/target
|
/target
|
||||||
TODO.txt
|
TODO.txt
|
||||||
nnue/batches
|
|
||||||
nnue/venv
|
|
||||||
|
@ -59,13 +59,9 @@ mkdir -p games
|
|||||||
PGN=games/"$TAG1"__"$TAG2".pgn
|
PGN=games/"$TAG1"__"$TAG2".pgn
|
||||||
|
|
||||||
rm -f engine1 engine2
|
rm -f engine1 engine2
|
||||||
|
if [ -f "$PGN" ]; then
|
||||||
IDX=1
|
rm -i "$PGN"
|
||||||
while [ -e "$PGN" ]; do
|
fi
|
||||||
PGN=games/"$TAG1"__"$TAG2"__"$IDX".pgn
|
|
||||||
IDX=$(( $IDX + 1 ))
|
|
||||||
done
|
|
||||||
printf "using pgn output: %s\n" "$PGN" > /dev/stderr
|
|
||||||
|
|
||||||
git checkout "$TAG1"
|
git checkout "$TAG1"
|
||||||
cargo build --release
|
cargo build --release
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
# NNUE training tools
|
|
||||||
|
|
||||||
Python training pipeline for the evaluation neural network.
|
|
||||||
See the docstring in `src/nnue.rs` for information about the architecture of the NNUE.
|
|
||||||
The network is trained on both self-play games, and its games on Lichess.
|
|
||||||
Both of these sources provide games in PGN format.
|
|
||||||
|
|
||||||
This folder includes the following scripts:
|
|
||||||
- `batch_pgn_data.py`: Combine and convert big PGN files into small chunked files.
|
|
||||||
- `process_pgn_data.py`: Convert PGN data into a format suitable for training.
|
|
||||||
|
|
||||||
Example training pipeline:
|
|
||||||
```bash
|
|
||||||
# chunk all the PGN files in `games/`. outputs by default to `batches/batch%d.pgn`.
|
|
||||||
./batch_pgn_data.py games/*.pgn
|
|
||||||
|
|
||||||
# analyze batches 0 to 20 to turn them into training data. outputs by default to train_data/batch%d.tsv.gz.
|
|
||||||
# set max-workers to the number of hardware threads / cores you have.
|
|
||||||
./process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch{0..20}.pgn
|
|
||||||
```
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
"""
|
|
||||||
Batch PGN data into files, since the training data pipeline can't resume processing within a single file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# /// script
|
|
||||||
# requires-python = ">=3.12"
|
|
||||||
# dependencies = [
|
|
||||||
# "chess",
|
|
||||||
# ]
|
|
||||||
# ///
|
|
||||||
|
|
||||||
from typing import Iterator
|
|
||||||
import chess.pgn
|
|
||||||
import argparse
|
|
||||||
import itertools
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
"""Games to include per file in output."""
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("files", nargs="+", type=Path)
|
|
||||||
parser.add_argument("--batch-size", type=int, help="Number of games to save in each output file.", default=8)
|
|
||||||
parser.add_argument("--output-folder", type=Path, help="Folder to save batched games in.", default=Path("batches"))
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
def generate_games_in_file(path: Path) -> Iterator[chess.pgn.Game]:
|
|
||||||
"""Read games from a single PGN file."""
|
|
||||||
with open(path) as f:
|
|
||||||
while game := chess.pgn.read_game(f):
|
|
||||||
game.headers["PGNPath"] = str(path)
|
|
||||||
yield game
|
|
||||||
|
|
||||||
def generate_games() -> Iterator[chess.pgn.Game]:
|
|
||||||
"""Read games from all files."""
|
|
||||||
for path in args.files:
|
|
||||||
yield from generate_games_in_file(path)
|
|
||||||
|
|
||||||
def batch_games():
|
|
||||||
"""Write games in batches."""
|
|
||||||
output_folder: Path = args.output_folder
|
|
||||||
output_folder.mkdir(exist_ok=True)
|
|
||||||
for idx, batch in enumerate(itertools.batched(generate_games(), args.batch_size)):
|
|
||||||
with (output_folder / f"batch{idx}.pgn").open("w") as f:
|
|
||||||
for game in batch:
|
|
||||||
f.write(str(game) + "\n\n")
|
|
||||||
|
|
||||||
batch_games()
|
|
@ -1,245 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
"""
|
|
||||||
Processes PGN game data into a tsv format suitable for training.
|
|
||||||
|
|
||||||
Output columns:
|
|
||||||
- FEN (for reference)
|
|
||||||
- ALL 768-bit binary string representing the position
|
|
||||||
- Evaluation (centipawns) from white perspective
|
|
||||||
- Result of the game (-1, 0, 1)
|
|
||||||
|
|
||||||
This script depends on the `chess` package.
|
|
||||||
Install it, or run this script using `pipx run process_pgn_data.py`.
|
|
||||||
The script also depends on the chess_inator engine for analysis and filtering.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# /// script
|
|
||||||
# requires-python = ">=3.11"
|
|
||||||
# dependencies = [
|
|
||||||
# "chess",
|
|
||||||
# ]
|
|
||||||
# ///
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from asyncio import Queue, TaskGroup, create_task, run, sleep
|
|
||||||
import logging
|
|
||||||
import datetime
|
|
||||||
import multiprocessing
|
|
||||||
import gzip
|
|
||||||
import csv
|
|
||||||
|
|
||||||
import chess
|
|
||||||
import chess.engine
|
|
||||||
from typing import AsyncIterator, Literal
|
|
||||||
from chess import pgn
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--log",
|
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
|
||||||
default="INFO",
|
|
||||||
help="Sets log level.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--engine",
|
|
||||||
help="Set the file path of the chess_inator engine used to analyze the positions.",
|
|
||||||
type=Path,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-workers",
|
|
||||||
help="Max concurrent workers to analyse games with (limit this to your hardware thread count).",
|
|
||||||
default=min(4, multiprocessing.cpu_count()),
|
|
||||||
type=int,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--preserve-partial",
|
|
||||||
action="store_true",
|
|
||||||
help="Keep output files that have not been fully written. These files may confuse this script when resuming operations.",
|
|
||||||
)
|
|
||||||
parser.add_argument("files", nargs="+", type=Path)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=getattr(logging, str.upper(args.log)))
|
|
||||||
|
|
||||||
|
|
||||||
"""Skip these many plies from the start (avoid training on opening)."""
|
|
||||||
SKIP_PLIES: int = 20
|
|
||||||
|
|
||||||
"""Time limit in seconds for each position to be analyzed."""
|
|
||||||
TIME_LIMIT: float = 3
|
|
||||||
|
|
||||||
|
|
||||||
output_queue: Queue[tuple[str, str, int, Literal[-1, 0, 1]]] = Queue()
|
|
||||||
|
|
||||||
|
|
||||||
# stats for progress
|
|
||||||
completed = 0
|
|
||||||
discarded = 0
|
|
||||||
current_outp: Path | None = None
|
|
||||||
start_time = datetime.datetime.now()
|
|
||||||
|
|
||||||
|
|
||||||
async def load_games(file: Path):
|
|
||||||
"""Load a PGN file and divide up the games for the workers to process."""
|
|
||||||
with open(file) as f:
|
|
||||||
while game := pgn.read_game(f):
|
|
||||||
yield game
|
|
||||||
|
|
||||||
|
|
||||||
async def worker(game_generator: AsyncIterator[pgn.Game]) -> None:
|
|
||||||
"""
|
|
||||||
Single worker that analyzes whole games.
|
|
||||||
|
|
||||||
Code pattern taken from https://stackoverflow.com/a/54975674.
|
|
||||||
|
|
||||||
Puts rows of output into a global queue.
|
|
||||||
"""
|
|
||||||
transport, engine = await chess.engine.popen_uci(args.engine)
|
|
||||||
await engine.configure(dict(NNUETrainInfo="true"))
|
|
||||||
|
|
||||||
async for game in game_generator:
|
|
||||||
wdl: int | None = None
|
|
||||||
|
|
||||||
match game.headers["Result"]:
|
|
||||||
case "1-0":
|
|
||||||
wdl = 1
|
|
||||||
case "0-1":
|
|
||||||
wdl = -1
|
|
||||||
case "1/2-1/2":
|
|
||||||
wdl = 0
|
|
||||||
case other_result:
|
|
||||||
logging.error("invalid 'Result' header: '%s'", other_result)
|
|
||||||
continue
|
|
||||||
|
|
||||||
board = game.board()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"Processing game %s, %s (%s) between %s as White and %s as Black.",
|
|
||||||
game.headers["Event"],
|
|
||||||
game.headers["Site"],
|
|
||||||
game.headers["Date"],
|
|
||||||
game.headers["White"],
|
|
||||||
game.headers["Black"],
|
|
||||||
)
|
|
||||||
|
|
||||||
for move in game.mainline_moves():
|
|
||||||
board.push(move)
|
|
||||||
if skipped < SKIP_PLIES:
|
|
||||||
skipped += 1
|
|
||||||
continue
|
|
||||||
result = await engine.play(
|
|
||||||
board,
|
|
||||||
chess.engine.Limit(time=TIME_LIMIT),
|
|
||||||
info=chess.engine.INFO_ALL,
|
|
||||||
game=game,
|
|
||||||
)
|
|
||||||
|
|
||||||
info_str = result.info.get("string")
|
|
||||||
if not info_str:
|
|
||||||
raise RuntimeError("Could not analyze position with engine.")
|
|
||||||
(name, quiet, eval_abs, tensor) = info_str.split()
|
|
||||||
if not name == "NNUETrainInfo":
|
|
||||||
raise RuntimeError(f"Unexpected output from engine: {info_str}")
|
|
||||||
|
|
||||||
if quiet == "non-quiet":
|
|
||||||
global discarded
|
|
||||||
discarded += 1
|
|
||||||
logging.debug("discarded as non-quiet: '%s'", board.fen())
|
|
||||||
continue
|
|
||||||
elif quiet != "quiet":
|
|
||||||
raise RuntimeError(f"Unexpected output from engine: {info_str}")
|
|
||||||
|
|
||||||
await output_queue.put((board.fen(), tensor, int(eval_abs), wdl))
|
|
||||||
|
|
||||||
|
|
||||||
async def analyse_games(file: Path):
|
|
||||||
"""Task that manages reading PGNs and analyzing them."""
|
|
||||||
games_generator = load_games(file)
|
|
||||||
|
|
||||||
async with TaskGroup() as tg:
|
|
||||||
worker_count: int = min(args.max_workers, multiprocessing.cpu_count())
|
|
||||||
logging.info("Using %d concurrent worker tasks.", worker_count)
|
|
||||||
for i in range(worker_count):
|
|
||||||
tg.create_task(worker(games_generator))
|
|
||||||
|
|
||||||
|
|
||||||
async def output_rows(outp_file: Path):
|
|
||||||
"""TSV writer task."""
|
|
||||||
|
|
||||||
with gzip.open(outp_file, "wt") as f:
|
|
||||||
writer = csv.writer(f, delimiter="\t")
|
|
||||||
while True:
|
|
||||||
row = await output_queue.get()
|
|
||||||
writer.writerow(row)
|
|
||||||
output_queue.task_done()
|
|
||||||
global completed
|
|
||||||
completed += 1
|
|
||||||
|
|
||||||
|
|
||||||
async def status_logger():
|
|
||||||
"""Periodically print status."""
|
|
||||||
while True:
|
|
||||||
await sleep(5)
|
|
||||||
logging.info(
|
|
||||||
"Completed %d rows in %f seconds. Discarded %d non-quiet positions.",
|
|
||||||
completed,
|
|
||||||
(datetime.datetime.now() - start_time).total_seconds(),
|
|
||||||
discarded,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
status_task = create_task(status_logger())
|
|
||||||
|
|
||||||
outp_dir = Path("train_data")
|
|
||||||
outp_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
any_file = False
|
|
||||||
skipped = False
|
|
||||||
|
|
||||||
for file in args.files:
|
|
||||||
file: Path
|
|
||||||
|
|
||||||
outp_file = outp_dir / file.with_suffix(".tsv.gz").name
|
|
||||||
|
|
||||||
if outp_file.exists():
|
|
||||||
skipped = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
any_file = True
|
|
||||||
|
|
||||||
if skipped:
|
|
||||||
logging.info("Resuming at file '%s'.", file)
|
|
||||||
else:
|
|
||||||
logging.info("Reading file '%s'.", file)
|
|
||||||
|
|
||||||
global current_outp
|
|
||||||
current_outp = outp_file
|
|
||||||
|
|
||||||
output_task = create_task(output_rows(outp_file))
|
|
||||||
analyse_task = create_task(analyse_games(file))
|
|
||||||
await analyse_task
|
|
||||||
output_task.cancel()
|
|
||||||
|
|
||||||
if not any_file:
|
|
||||||
logging.warning("Nothing to do. All input files have outputs already.")
|
|
||||||
|
|
||||||
status_task.cancel()
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
run(main())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logging.critical("shutting down.")
|
|
||||||
if current_outp and not args.preserve_partial:
|
|
||||||
logging.critical("discarding partial output file %s", current_outp)
|
|
||||||
current_outp.unlink()
|
|
49
src/lib.rs
49
src/lib.rs
@ -19,10 +19,10 @@ use std::str::FromStr;
|
|||||||
|
|
||||||
pub mod coordination;
|
pub mod coordination;
|
||||||
pub mod eval;
|
pub mod eval;
|
||||||
|
pub mod nnue;
|
||||||
pub mod fen;
|
pub mod fen;
|
||||||
mod hash;
|
mod hash;
|
||||||
pub mod movegen;
|
pub mod movegen;
|
||||||
pub mod nnue;
|
|
||||||
pub mod random;
|
pub mod random;
|
||||||
pub mod search;
|
pub mod search;
|
||||||
|
|
||||||
@ -554,7 +554,7 @@ const HISTORY_SIZE: usize = 100;
|
|||||||
|
|
||||||
impl BoardHistory {
|
impl BoardHistory {
|
||||||
/// Counts occurences of this hash in the history.
|
/// Counts occurences of this hash in the history.
|
||||||
fn _count(&self, hash: Zobrist) -> usize {
|
fn count(&self, hash: Zobrist) -> usize {
|
||||||
let mut ans = 0;
|
let mut ans = 0;
|
||||||
|
|
||||||
let mut i = self.ptr_start;
|
let mut i = self.ptr_start;
|
||||||
@ -568,20 +568,6 @@ impl BoardHistory {
|
|||||||
ans
|
ans
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find if there are at least `n` matches for a hash in the last `recent` plies.
|
|
||||||
fn at_least_in_recent(&self, mut n: usize, recent: usize, hash: Zobrist) -> bool {
|
|
||||||
let mut i = self.ptr_end - recent;
|
|
||||||
|
|
||||||
while i != self.ptr_end && n > 0 {
|
|
||||||
if self.hashes[usize::from(i)] == hash {
|
|
||||||
n -= 1;
|
|
||||||
}
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
n == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add (push) hash to history.
|
/// Add (push) hash to history.
|
||||||
fn push(&mut self, hash: Zobrist) {
|
fn push(&mut self, hash: Zobrist) {
|
||||||
self.hashes[usize::from(self.ptr_end)] = hash;
|
self.hashes[usize::from(self.ptr_end)] = hash;
|
||||||
@ -647,12 +633,6 @@ impl Board {
|
|||||||
self.history.push(self.zobrist);
|
self.history.push(self.zobrist);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Is this position a draw by three repetitions?
|
|
||||||
pub fn is_repetition(&mut self) -> bool {
|
|
||||||
self.history
|
|
||||||
.at_least_in_recent(2, self.half_moves, self.zobrist)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get iterator over all squares.
|
/// Get iterator over all squares.
|
||||||
pub fn squares() -> impl Iterator<Item = Square> {
|
pub fn squares() -> impl Iterator<Item = Square> {
|
||||||
(0..N_SQUARES).map(Square::try_from).map(|x| x.unwrap())
|
(0..N_SQUARES).map(Square::try_from).map(|x| x.unwrap())
|
||||||
@ -968,30 +948,20 @@ mod tests {
|
|||||||
history.push(board.zobrist);
|
history.push(board.zobrist);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 1);
|
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 1);
|
||||||
assert!(history.at_least_in_recent(1, 1, board.zobrist));
|
|
||||||
assert!(history.at_least_in_recent(2, 3, board.zobrist));
|
|
||||||
assert!(history.at_least_in_recent(1, 3, board.zobrist));
|
|
||||||
|
|
||||||
let board_empty = Board::default();
|
let board_empty = Board::default();
|
||||||
history.push(board_empty.zobrist);
|
history.push(board_empty.zobrist);
|
||||||
|
|
||||||
assert!(!history.at_least_in_recent(1, 1, board.zobrist));
|
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 2);
|
||||||
assert!(history.at_least_in_recent(1, 2, board.zobrist));
|
assert_eq!(history.count(board_empty.zobrist), 1);
|
||||||
|
|
||||||
assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 2);
|
|
||||||
assert_eq!(history._count(board_empty.zobrist), 1);
|
|
||||||
assert!(history.at_least_in_recent(1, 3, board.zobrist));
|
|
||||||
assert!(history.at_least_in_recent(1, 20, board_empty.zobrist));
|
|
||||||
assert!(history.at_least_in_recent(1, 15, board_empty.zobrist));
|
|
||||||
assert!(history.at_least_in_recent(1, 1, board_empty.zobrist));
|
|
||||||
|
|
||||||
for _ in 0..3 {
|
for _ in 0..3 {
|
||||||
history.push(board_empty.zobrist);
|
history.push(board_empty.zobrist);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(history._count(board_empty.zobrist), 4);
|
assert_eq!(history.count(board_empty.zobrist), 4);
|
||||||
assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 5);
|
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -1012,8 +982,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// this is the third occurence, but beforehand there are two occurences
|
// this is the third occurence, but beforehand there are two occurences
|
||||||
assert_eq!(board.history._count(board.zobrist), 2);
|
assert_eq!(board.history.count(board.zobrist), 2);
|
||||||
assert!(board.is_repetition(), "fen: {}", board.to_fen());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// engine should take advantage of the three time repetition rule
|
/// engine should take advantage of the three time repetition rule
|
||||||
@ -1063,7 +1032,7 @@ mod tests {
|
|||||||
expected_bestmv.make(&mut board);
|
expected_bestmv.make(&mut board);
|
||||||
eprintln!(
|
eprintln!(
|
||||||
"after expected mv, board repeated {} times",
|
"after expected mv, board repeated {} times",
|
||||||
board.history._count(board.zobrist)
|
board.history.count(board.zobrist)
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
13
src/main.rs
13
src/main.rs
@ -370,18 +370,19 @@ fn task_engine(tx_main: Sender<MsgToMain>, rx_engine: Receiver<MsgToEngine>) {
|
|||||||
let mut info: Vec<String> = Vec::new();
|
let mut info: Vec<String> = Vec::new();
|
||||||
if state.config.nnue_train_info {
|
if state.config.nnue_train_info {
|
||||||
let is_quiet = chess_inator::search::is_quiescent_position(&board, eval);
|
let is_quiet = chess_inator::search::is_quiescent_position(&board, eval);
|
||||||
let is_quiet = if is_quiet { "quiet" } else { "non-quiet" };
|
let is_quiet = if is_quiet {"quiet"} else {"non-quiet"};
|
||||||
|
|
||||||
let board_tensor = chess_inator::nnue::InputTensor::from_board(&board);
|
let board_tensor = chess_inator::nnue::InputTensor::from_board(&board);
|
||||||
|
|
||||||
let abs_eval = EvalInt::from(eval) * EvalInt::from(board.get_turn().sign());
|
info.push(format!("NNUETrainInfo {} {}", is_quiet, {board_tensor}))
|
||||||
info.push(format!("NNUETrainInfo {} {} {}", is_quiet, abs_eval, {
|
|
||||||
board_tensor
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tx_main
|
tx_main
|
||||||
.send(MsgToMain::Bestmove(MsgBestmove { pv, eval, info }))
|
.send(MsgToMain::Bestmove(MsgBestmove {
|
||||||
|
pv,
|
||||||
|
eval,
|
||||||
|
info,
|
||||||
|
}))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
MsgToEngine::Stop => {}
|
MsgToEngine::Stop => {}
|
||||||
|
102
src/search.rs
102
src/search.rs
@ -224,7 +224,7 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_repetition_draw = board.is_repetition();
|
let is_repetition_draw = board.history.count(board.zobrist) >= 2;
|
||||||
let phase_factor = EvalInt::from(board.eval.min_maj_pieces / 5);
|
let phase_factor = EvalInt::from(board.eval.min_maj_pieces / 5);
|
||||||
// positive here since we're looking from the opposite perspective.
|
// positive here since we're looking from the opposite perspective.
|
||||||
// if white caused a draw, then we'd be black here.
|
// if white caused a draw, then we'd be black here.
|
||||||
@ -262,60 +262,37 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enum MoveGenerator {
|
// default to worst, then gradually improve
|
||||||
/// Use heavily pruned search to generate moves leading to a quiet position.
|
let mut alpha = mm.alpha.unwrap_or(EVAL_WORST);
|
||||||
Quiescence,
|
// our best is their worst
|
||||||
/// Generate all legal moves.
|
let beta = mm.beta.unwrap_or(EVAL_BEST);
|
||||||
Normal,
|
|
||||||
/// Only evaluate a single move.
|
let mvs = if mm.quiesce {
|
||||||
None,
|
board.gen_captures().into_iter().collect::<Vec<_>>()
|
||||||
}
|
|
||||||
let mut move_generator = if mm.quiesce {
|
|
||||||
MoveGenerator::Quiescence
|
|
||||||
} else {
|
} else {
|
||||||
MoveGenerator::Normal
|
board.gen_moves().into_iter().collect::<Vec<_>>()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut trans_table_move: Option<Move> = None;
|
|
||||||
|
|
||||||
// get transposition table entry
|
|
||||||
if state.config.enable_trans_table {
|
|
||||||
if let Some(entry) = &state.cache[board.zobrist] {
|
|
||||||
trans_table_move = Some(entry.best_move);
|
|
||||||
if entry.is_qsearch == mm.quiesce && entry.depth >= mm.depth {
|
|
||||||
if let SearchEval::Exact(_) | SearchEval::Upper(_) = entry.eval {
|
|
||||||
// at this point, we could just return the best move + eval given, but this
|
|
||||||
// bypasses the draw by repetition checks in `minmax`. so just don't generate
|
|
||||||
// any other moves than the best move.
|
|
||||||
move_generator = MoveGenerator::None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mvs = match move_generator {
|
|
||||||
MoveGenerator::Quiescence => board.gen_captures().into_iter().collect::<Vec<_>>(),
|
|
||||||
MoveGenerator::Normal => board.gen_moves().into_iter().collect::<Vec<_>>(),
|
|
||||||
MoveGenerator::None => Vec::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut mvs: Vec<_> = mvs
|
let mut mvs: Vec<_> = mvs
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|mv| (move_priority(board, &mv, state), mv))
|
.map(|mv| (move_priority(board, &mv, state), mv))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if let Some(trans_table_move) = trans_table_move {
|
// get transposition table entry
|
||||||
mvs.push((EVAL_BEST, trans_table_move))
|
if state.config.enable_trans_table {
|
||||||
|
if let Some(entry) = &state.cache[board.zobrist] {
|
||||||
|
if entry.is_qsearch == mm.quiesce && entry.depth >= mm.depth {
|
||||||
|
if let SearchEval::Exact(_) | SearchEval::Upper(_) = entry.eval {
|
||||||
|
// no point looking for a better move
|
||||||
|
return (vec![entry.best_move], entry.eval);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mvs.push((EVAL_BEST, entry.best_move));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sort moves by decreasing priority
|
// sort moves by decreasing priority
|
||||||
mvs.sort_unstable_by_key(|mv| -mv.0);
|
mvs.sort_unstable_by_key(|mv| -mv.0);
|
||||||
|
|
||||||
// default to worst, then gradually improve
|
|
||||||
let mut alpha = mm.alpha.unwrap_or(EVAL_WORST);
|
|
||||||
// our best is their worst
|
|
||||||
let beta = mm.beta.unwrap_or(EVAL_BEST);
|
|
||||||
|
|
||||||
let mut abs_best = SearchEval::Exact(EVAL_WORST);
|
let mut abs_best = SearchEval::Exact(EVAL_WORST);
|
||||||
|
|
||||||
if mm.quiesce {
|
if mm.quiesce {
|
||||||
@ -363,7 +340,6 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
|
|||||||
quiesce: mm.quiesce,
|
quiesce: mm.quiesce,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
anti_mv.unmake(board);
|
|
||||||
|
|
||||||
// propagate hard stops
|
// propagate hard stops
|
||||||
if matches!(score, SearchEval::Stopped) {
|
if matches!(score, SearchEval::Stopped) {
|
||||||
@ -377,6 +353,7 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
|
|||||||
best_continuation = continuation;
|
best_continuation = continuation;
|
||||||
}
|
}
|
||||||
alpha = max(alpha, abs_best.into());
|
alpha = max(alpha, abs_best.into());
|
||||||
|
anti_mv.unmake(board);
|
||||||
if alpha >= beta && state.config.alpha_beta_on {
|
if alpha >= beta && state.config.alpha_beta_on {
|
||||||
// alpha-beta prune.
|
// alpha-beta prune.
|
||||||
//
|
//
|
||||||
@ -480,7 +457,7 @@ impl TimeLimits {
|
|||||||
// hard timeout (max)
|
// hard timeout (max)
|
||||||
let mut hard_ms = 100_000;
|
let mut hard_ms = 100_000;
|
||||||
// soft timeout (default max)
|
// soft timeout (default max)
|
||||||
let mut soft_ms = 1_500;
|
let mut soft_ms = 1_200;
|
||||||
|
|
||||||
// in some situations we can think longer
|
// in some situations we can think longer
|
||||||
if eval.phase <= 13 {
|
if eval.phase <= 13 {
|
||||||
@ -488,11 +465,11 @@ impl TimeLimits {
|
|||||||
// opening
|
// opening
|
||||||
|
|
||||||
soft_ms = if ourtime_ms > 300_000 {
|
soft_ms = if ourtime_ms > 300_000 {
|
||||||
3_000
|
4_500
|
||||||
} else if ourtime_ms > 600_000 {
|
} else if ourtime_ms > 600_000 {
|
||||||
5_000
|
|
||||||
} else if ourtime_ms > 1_200_000 {
|
|
||||||
8_000
|
8_000
|
||||||
|
} else if ourtime_ms > 1_200_000 {
|
||||||
|
12_000
|
||||||
} else {
|
} else {
|
||||||
soft_ms
|
soft_ms
|
||||||
}
|
}
|
||||||
@ -581,7 +558,7 @@ pub fn best_move(board: &mut Board, engine_state: &mut EngineState) -> Option<Mo
|
|||||||
/// It is the caller's responsibility to get the search evaluation and pass it to this function.
|
/// It is the caller's responsibility to get the search evaluation and pass it to this function.
|
||||||
pub fn is_quiescent_position(board: &Board, eval: SearchEval) -> bool {
|
pub fn is_quiescent_position(board: &Board, eval: SearchEval) -> bool {
|
||||||
// max centipawn value difference to call "similar"
|
// max centipawn value difference to call "similar"
|
||||||
const THRESHOLD: EvalInt = 120;
|
const THRESHOLD: EvalInt = 170;
|
||||||
|
|
||||||
if board.is_check(board.turn) {
|
if board.is_check(board.turn) {
|
||||||
return false;
|
return false;
|
||||||
@ -596,30 +573,3 @@ pub fn is_quiescent_position(board: &Board, eval: SearchEval) -> bool {
|
|||||||
|
|
||||||
(board.eval() - EvalInt::from(abs_eval)).abs() <= THRESHOLD.abs()
|
(board.eval() - EvalInt::from(abs_eval)).abs() <= THRESHOLD.abs()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
/// Test that running minmax does not alter the board.
|
|
||||||
#[test]
|
|
||||||
fn test_board_same() {
|
|
||||||
let (_tx, rx) = mpsc::channel();
|
|
||||||
let cache = TranspositionTable::new(1);
|
|
||||||
let mut engine_state = EngineState::new(
|
|
||||||
SearchConfig {
|
|
||||||
depth: 3,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
rx,
|
|
||||||
cache,
|
|
||||||
TimeLimits::from_movetime(20),
|
|
||||||
);
|
|
||||||
let mut board =
|
|
||||||
Board::from_fen("2rq1rk1/pp1bbppp/3p4/4p1B1/2B1P1n1/1PN5/P1PQ1PPP/R3K2R w KQ - 1 14")
|
|
||||||
.unwrap();
|
|
||||||
let orig_board = board;
|
|
||||||
let (_line, _eval) = best_line(&mut board, &mut engine_state);
|
|
||||||
assert_eq!(board, orig_board)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user