feat: soft/hard time limit
achieved by refactoring engine/main/stdin into three separate threads.
This commit is contained in:
parent
6be00e642e
commit
e44cc0586e
191
src/coordination.rs
Normal file
191
src/coordination.rs
Normal file
@ -0,0 +1,191 @@
|
||||
/*
|
||||
|
||||
This file is part of chess_inator.
|
||||
chess_inator is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.
|
||||
|
||||
chess_inator is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License along with chess_inator. If not, see https://www.gnu.org/licenses/.
|
||||
|
||||
Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
|
||||
*/
|
||||
|
||||
//! Threading, state, and flow of information management.
|
||||
//!
|
||||
//! This file contains types and helper utilities; see main for actual implementation.
|
||||
|
||||
use crate::prelude::*;
|
||||
|
||||
/// State machine states.
///
/// Modes are plain values: they can be copied and compared with `==`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UCIMode {
    /// It is the engine's turn; the engine is thinking about a move.
    Think,
    /// It is the opponent's turn; the engine is thinking on the opponent's time.
    Ponder,
    /// The engine is not doing anything.
    Idle,
}
|
||||
|
||||
/// State machine transitions.
///
/// Each variant documents the modes it moves between, written as `From -> To`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UCIModeTransition {
    /// Engine produces a best move result. Think -> Idle.
    Bestmove,
    /// Engine is stopped via a UCI `stop` command. Think/Ponder -> Idle.
    Stop,
    /// Engine is asked for a best move through a UCI `go`. Idle -> Think.
    Go,
    /// Engine starts pondering on the opponent's time. Idle -> Ponder.
    GoPonder,
    /// While engine ponders, the opponent plays a different move than expected. Ponder -> Think.
    ///
    /// In UCI, this means that a new `position` command is sent.
    PonderMiss,
    /// While engine ponders, the opponent plays the expected move (`ponderhit`). Ponder -> Think.
    PonderHit,
}
|
||||
|
||||
impl UCIModeTransition {
|
||||
/// The state that a transition goes to.
|
||||
const fn dest_mode(&self) -> UCIMode {
|
||||
use UCIMode::*;
|
||||
use UCIModeTransition::*;
|
||||
match self {
|
||||
Bestmove => Idle,
|
||||
Stop => Idle,
|
||||
Go => Think,
|
||||
GoPonder => Ponder,
|
||||
PonderMiss => Think,
|
||||
PonderHit => Think,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State machine for engine's UCI modes.
|
||||
#[derive(Debug)]
|
||||
pub struct UCIModeMachine {
|
||||
pub mode: UCIMode,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct InvalidTransitionError {
|
||||
/// Original state.
|
||||
pub from: UCIMode,
|
||||
/// Desired destination state.
|
||||
pub to: UCIMode,
|
||||
}
|
||||
|
||||
impl Default for UCIModeMachine {
|
||||
fn default() -> Self {
|
||||
UCIModeMachine {
|
||||
mode: UCIMode::Idle,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UCIModeMachine {
|
||||
/// Change state (checked to prevent invalid transitions.)
|
||||
pub fn transition(&mut self, t: UCIModeTransition) -> Result<(), InvalidTransitionError> {
|
||||
macro_rules! illegal {
|
||||
() => {
|
||||
return Err(InvalidTransitionError {
|
||||
from: self.mode,
|
||||
to: t.dest_mode(),
|
||||
})
|
||||
};
|
||||
}
|
||||
macro_rules! legal {
|
||||
() => {{
|
||||
self.mode = t.dest_mode();
|
||||
return Ok(());
|
||||
}};
|
||||
}
|
||||
|
||||
use UCIModeTransition::*;
|
||||
|
||||
match t {
|
||||
Bestmove => match self.mode {
|
||||
UCIMode::Think => legal!(),
|
||||
_ => illegal!(),
|
||||
},
|
||||
Stop => match self.mode {
|
||||
UCIMode::Ponder | UCIMode::Think => legal!(),
|
||||
_ => illegal!(),
|
||||
},
|
||||
Go | GoPonder => match self.mode {
|
||||
UCIMode::Idle => legal!(),
|
||||
_ => illegal!(),
|
||||
},
|
||||
PonderMiss => match self.mode {
|
||||
UCIMode::Ponder => legal!(),
|
||||
_ => illegal!(),
|
||||
},
|
||||
PonderHit => match self.mode {
|
||||
UCIMode::Ponder => legal!(),
|
||||
_ => illegal!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test_state_machine {
    use super::*;

    /// Non-exhaustive test of state machine.
    #[test]
    fn test_transitions() {
        let mut machine = UCIModeMachine {
            mode: UCIMode::Idle,
        };

        // Idle -> Think, then back to Idle through an explicit stop.
        assert!(machine.transition(UCIModeTransition::Go).is_ok());
        assert!(matches!(machine.mode, UCIMode::Think));
        assert!(machine.transition(UCIModeTransition::Stop).is_ok());
        assert!(matches!(machine.mode, UCIMode::Idle));

        // Idle -> Think, then back to Idle because a best move was produced.
        assert!(machine.transition(UCIModeTransition::Go).is_ok());
        assert!(machine.transition(UCIModeTransition::Bestmove).is_ok());
        assert!(matches!(machine.mode, UCIMode::Idle));

        // A best move cannot be produced while idle.
        assert!(machine.transition(UCIModeTransition::Bestmove).is_err());
    }
}
|
||||
|
||||
/// Message (engine->main) to communicate the best move.
|
||||
pub struct MsgBestmove {
|
||||
/// Best line (reversed stack; last element is best current move)
|
||||
pub pv: Vec<Move>,
|
||||
/// Evaluation of the position
|
||||
pub eval: SearchEval,
|
||||
}
|
||||
|
||||
/// Interface messages that may be received by main's channel.
|
||||
pub enum MsgToMain {
|
||||
StdinLine(String),
|
||||
Bestmove(MsgBestmove),
|
||||
}
|
||||
|
||||
pub struct GoMessage {
|
||||
pub board: Board,
|
||||
pub config: SearchConfig,
|
||||
pub time_lims: TimeLimits,
|
||||
}
|
||||
|
||||
/// Main -> Engine thread channel message.
|
||||
pub enum MsgToEngine {
|
||||
/// `go` command. Also sends board position and engine configuration to avoid state
|
||||
/// synchronization issues (i.e. avoid sending position after a go command, and not before).
|
||||
Go(Box<GoMessage>),
|
||||
/// Hard stop command. Halt search immediately.
|
||||
Stop,
|
||||
/// Ask the engine to wipe its state (notably transposition table).
|
||||
NewGame,
|
||||
}
|
@ -17,6 +17,7 @@ use std::fmt::Display;
|
||||
use std::ops::{Index, IndexMut};
|
||||
use std::str::FromStr;
|
||||
|
||||
pub mod coordination;
|
||||
pub mod eval;
|
||||
pub mod fen;
|
||||
mod hash;
|
||||
@ -432,7 +433,7 @@ impl Display for CastleRights {
|
||||
}
|
||||
}
|
||||
|
||||
/// Immutable game state, unique to a position.
|
||||
/// Game state, describes a position.
|
||||
///
|
||||
/// Default is empty.
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
|
||||
|
304
src/main.rs
304
src/main.rs
@ -11,14 +11,33 @@ Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
|
||||
*/
|
||||
|
||||
//! Main UCI engine binary.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! This runs three threads, Main, Engine, and Stdin. Main coordinates everything, and performs UCI
|
||||
//! parsing/communication. Stdin is read on a different thread, in order to avoid blocking on it.
|
||||
//! The Engine is where the actual computation happens. It communicates state (best move, evaluations,
|
||||
//! board state and configuration) with Main.
|
||||
//!
|
||||
//! Main has a single rx (receive) channel. This is so that it can wait for either the Engine to
|
||||
//! finish a computation, or for Stdin to receive a UCI command. This way, the overall engine
|
||||
//! program is always listening, even when it is thinking.
|
||||
//!
|
||||
//! For every go command, Main sends data, notably the current position and engine configuration,
|
||||
//! to the Engine. The current position and config are re-sent every time because Main is where the
|
||||
//! opponent's move, as well as any configuration options, are read and parsed. Meanwhile, internal
|
||||
//! data, like the transposition table, is owned by the Engine thread.
|
||||
//!
|
||||
//! # Notes
|
||||
//!
|
||||
//! - The naming scheme for channels here is `tx_main`, `rx_main` for "transmit to Main" and
|
||||
//! "receive at Main" respectively. These names would be used for one channel.
|
||||
|
||||
use chess_inator::prelude::*;
|
||||
use std::cmp::min;
|
||||
use std::io;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::process::exit;
|
||||
use std::sync::mpsc::{channel, Receiver, Sender};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
|
||||
/// UCI protocol says to ignore any unknown words.
|
||||
///
|
||||
@ -55,7 +74,7 @@ fn cmd_position_moves(mut tokens: std::str::SplitWhitespace<'_>, mut board: Boar
|
||||
}
|
||||
|
||||
/// Sets the position.
|
||||
fn cmd_position(mut tokens: std::str::SplitWhitespace<'_>) -> Board {
|
||||
fn cmd_position(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
while let Some(token) = tokens.next() {
|
||||
match token {
|
||||
"fen" => {
|
||||
@ -72,80 +91,138 @@ fn cmd_position(mut tokens: std::str::SplitWhitespace<'_>) -> Board {
|
||||
.unwrap_or_else(|e| panic!("failed to parse fen '{fen}': {e:?}"));
|
||||
let board = cmd_position_moves(tokens, board);
|
||||
|
||||
return board;
|
||||
state.board = board;
|
||||
return;
|
||||
}
|
||||
"startpos" => {
|
||||
let board = Board::starting_pos();
|
||||
let board = cmd_position_moves(tokens, board);
|
||||
|
||||
return board;
|
||||
state.board = board;
|
||||
return;
|
||||
}
|
||||
_ => ignore!(),
|
||||
}
|
||||
}
|
||||
|
||||
panic!("position command was empty")
|
||||
eprintln!("cmd_position: position command was empty")
|
||||
}
|
||||
|
||||
/// Play the game.
|
||||
fn cmd_go(
|
||||
mut tokens: std::str::SplitWhitespace<'_>,
|
||||
board: &mut Board,
|
||||
cache: &mut TranspositionTable,
|
||||
) {
|
||||
// interface-to-engine
|
||||
let (tx1, rx) = channel();
|
||||
let tx2 = tx1.clone();
|
||||
fn cmd_go(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
let mut wtime = 0;
|
||||
let mut btime = 0;
|
||||
|
||||
// can expect a 1sec soft timeout to result in more time than that of thinking
|
||||
let mut timeout = 1650;
|
||||
macro_rules! set_time {
|
||||
($color: expr, $var: ident) => {
|
||||
if let Some(time) = tokens.next() {
|
||||
if let Ok(time) = time.parse::<u64>() {
|
||||
$var = time;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
while let Some(token) = tokens.next() {
|
||||
match token {
|
||||
"wtime" => {
|
||||
if board.get_turn() == Color::White {
|
||||
if let Some(time) = tokens.next() {
|
||||
if let Ok(time) = time.parse::<u64>() {
|
||||
timeout = min(time / 50, timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
set_time!(Color::White, wtime)
|
||||
}
|
||||
"btime" => {
|
||||
if board.get_turn() == Color::Black {
|
||||
if let Some(time) = tokens.next() {
|
||||
if let Ok(time) = time.parse::<u64>() {
|
||||
timeout = min(time / 50, timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
set_time!(Color::Black, btime)
|
||||
}
|
||||
_ => ignore!(),
|
||||
}
|
||||
}
|
||||
|
||||
// timeout
|
||||
thread::spawn(move || {
|
||||
thread::sleep(Duration::from_millis(timeout));
|
||||
let _ = tx2.send(InterfaceMsg::Stop);
|
||||
});
|
||||
let (ourtime_ms, theirtime_ms) = if state.board.get_turn() == Color::White {
|
||||
(wtime, btime)
|
||||
} else {
|
||||
(btime, wtime)
|
||||
};
|
||||
|
||||
let mut engine_state = EngineState::new(SearchConfig::default(), rx, cache);
|
||||
let (line, eval) = best_line(board, &mut engine_state);
|
||||
state
|
||||
.tx_engine
|
||||
.send(MsgToEngine::Go(Box::new(GoMessage {
|
||||
board: state.board,
|
||||
config: state.config,
|
||||
time_lims: TimeLimits::from_ourtime_theirtime(ourtime_ms, theirtime_ms),
|
||||
})))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let chosen = line.last().copied();
|
||||
/// Print static evaluation of the position.
|
||||
fn cmd_eval(mut _tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
let res = eval_metrics(&state.board);
|
||||
println!("STATIC EVAL (negative black, positive white):\n- pst: {}\n- king distance: {} ({} distance)\n- phase: {}\n- total: {}", res.pst_eval, res.king_distance_eval, res.king_distance, res.phase, res.total_eval);
|
||||
}
|
||||
|
||||
/// Root UCI parser.
|
||||
fn cmd_root(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
while let Some(token) = tokens.next() {
|
||||
match token {
|
||||
"uci" => {
|
||||
println!("{}", cmd_uci());
|
||||
}
|
||||
"isready" => {
|
||||
println!("readyok");
|
||||
}
|
||||
"ucinewgame" => {
|
||||
if matches!(state.uci_mode.mode, UCIMode::Idle) {
|
||||
state.tx_engine.send(MsgToEngine::NewGame).unwrap();
|
||||
state.board = Board::starting_pos();
|
||||
}
|
||||
}
|
||||
"quit" => {
|
||||
exit(0);
|
||||
}
|
||||
"position" => {
|
||||
if matches!(state.uci_mode.mode, UCIMode::Idle) {
|
||||
cmd_position(tokens, state);
|
||||
}
|
||||
}
|
||||
"go" => {
|
||||
if state.uci_mode.transition(UCIModeTransition::Go).is_ok() {
|
||||
cmd_go(tokens, state);
|
||||
}
|
||||
}
|
||||
"stop" => {
|
||||
// actually setting state to stop happens when bestmove is received
|
||||
if matches!(state.uci_mode.mode, UCIMode::Think | UCIMode::Ponder) {
|
||||
state.tx_engine.send(MsgToEngine::Stop).unwrap();
|
||||
}
|
||||
}
|
||||
// non-standard command.
|
||||
"eval" => {
|
||||
cmd_eval(tokens, state);
|
||||
}
|
||||
_ => ignore!(),
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a bestmove.
|
||||
fn outp_bestmove(bestmove: MsgBestmove) {
|
||||
let chosen = bestmove.pv.last().copied();
|
||||
println!(
|
||||
"info pv{}",
|
||||
line.iter()
|
||||
bestmove
|
||||
.pv
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|mv| mv.to_uci_algebraic())
|
||||
.fold(String::new(), |a, b| a + " " + &b)
|
||||
);
|
||||
match eval {
|
||||
match bestmove.eval {
|
||||
SearchEval::Checkmate(n) => println!("info score mate {}", n / 2),
|
||||
SearchEval::Centipawns(eval) => {
|
||||
println!("info score cp {}", eval,)
|
||||
}
|
||||
SearchEval::Stopped => {
|
||||
panic!("info string ERROR: stopped search")
|
||||
}
|
||||
}
|
||||
match chosen {
|
||||
Some(mv) => println!("bestmove {}", mv.to_uci_algebraic()),
|
||||
@ -153,51 +230,118 @@ fn cmd_go(
|
||||
}
|
||||
}
|
||||
|
||||
/// Print static evaluation of the position.
|
||||
fn cmd_eval(mut _tokens: std::str::SplitWhitespace<'_>, board: &mut Board) {
|
||||
let res = eval_metrics(board);
|
||||
println!("STATIC EVAL (negative black, positive white):\n- pst: {}\n- king distance: {} ({} distance)\n- phase: {}\n- total: {}", res.pst_eval, res.king_distance_eval, res.king_distance, res.phase, res.total_eval);
|
||||
/// The "Stdin" thread to read stdin while avoiding blocking
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `tx_main`: channel write end to send lines to
|
||||
fn task_stdin_reader(tx_main: Sender<MsgToMain>) {
|
||||
thread::spawn(move || {
|
||||
let stdin = io::stdin();
|
||||
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
stdin.read_line(&mut line).unwrap();
|
||||
tx_main.send(MsgToMain::StdinLine(line)).unwrap();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let stdin = io::stdin();
|
||||
/// The "Engine" thread that does all the computation.
|
||||
fn task_engine(tx_main: Sender<MsgToMain>, rx_engine: Receiver<MsgToEngine>) {
|
||||
thread::spawn(move || {
|
||||
let conf = SearchConfig::default();
|
||||
let mut state = EngineState::new(
|
||||
conf,
|
||||
rx_engine,
|
||||
TranspositionTable::new(conf.transposition_size),
|
||||
TimeLimits::default(),
|
||||
);
|
||||
|
||||
let mut board = Board::starting_pos();
|
||||
let mut transposition_table = TranspositionTable::new(24);
|
||||
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
stdin.read_line(&mut line).unwrap();
|
||||
let mut tokens = line.split_whitespace();
|
||||
while let Some(token) = tokens.next() {
|
||||
match token {
|
||||
"uci" => {
|
||||
println!("{}", cmd_uci());
|
||||
loop {
|
||||
let msg = state.rx_engine.recv().unwrap();
|
||||
match msg {
|
||||
MsgToEngine::Go(msg_box) => {
|
||||
let mut board = msg_box.board;
|
||||
state.config = msg_box.config;
|
||||
state.time_lims = msg_box.time_lims;
|
||||
let (pv, eval) = best_line(&mut board, &mut state);
|
||||
tx_main
|
||||
.send(MsgToMain::Bestmove(MsgBestmove { pv, eval }))
|
||||
.unwrap();
|
||||
}
|
||||
"isready" => {
|
||||
println!("readyok");
|
||||
MsgToEngine::Stop => {}
|
||||
MsgToEngine::NewGame => {
|
||||
state.wipe_state();
|
||||
}
|
||||
"ucinewgame" => {
|
||||
board = Board::starting_pos();
|
||||
transposition_table = TranspositionTable::new(24);
|
||||
}
|
||||
"quit" => {
|
||||
return;
|
||||
}
|
||||
"position" => {
|
||||
board = cmd_position(tokens);
|
||||
}
|
||||
"go" => {
|
||||
cmd_go(tokens, &mut board, &mut transposition_table);
|
||||
}
|
||||
// non-standard command.
|
||||
"eval" => {
|
||||
cmd_eval(tokens, &mut board);
|
||||
}
|
||||
_ => ignore!(),
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
break;
|
||||
/// State contained within the main thread.
|
||||
///
|
||||
/// This struct helps pass around this thread state.
|
||||
struct MainState {
|
||||
/// Channel to send messages to Engine.
|
||||
tx_engine: Sender<MsgToEngine>,
|
||||
/// Channel to receive messages from Engine and Stdin.
|
||||
rx_main: Receiver<MsgToMain>,
|
||||
/// Chessboard.
|
||||
board: Board,
|
||||
/// Engine configuration settings.
|
||||
config: SearchConfig,
|
||||
/// UCI mode state machine
|
||||
uci_mode: UCIModeMachine,
|
||||
}
|
||||
|
||||
impl MainState {
|
||||
fn new(
|
||||
tx_engine: Sender<MsgToEngine>,
|
||||
rx_main: Receiver<MsgToMain>,
|
||||
board: Board,
|
||||
config: SearchConfig,
|
||||
uci_mode: UCIModeMachine,
|
||||
) -> Self {
|
||||
Self {
|
||||
tx_engine,
|
||||
rx_main,
|
||||
board,
|
||||
config,
|
||||
uci_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The "Main" thread.
|
||||
fn main() {
|
||||
let (tx_main, rx_main) = channel();
|
||||
task_stdin_reader(tx_main.clone());
|
||||
|
||||
let (tx_engine, rx_engine) = channel();
|
||||
task_engine(tx_main, rx_engine);
|
||||
|
||||
let mut state = MainState::new(
|
||||
tx_engine,
|
||||
rx_main,
|
||||
Board::starting_pos(),
|
||||
SearchConfig::default(),
|
||||
UCIModeMachine::default(),
|
||||
);
|
||||
|
||||
loop {
|
||||
let msg = state.rx_main.recv().unwrap();
|
||||
match msg {
|
||||
MsgToMain::StdinLine(line) => {
|
||||
let tokens = line.split_whitespace();
|
||||
cmd_root(tokens, &mut state);
|
||||
}
|
||||
MsgToMain::Bestmove(msg_bestmove) => {
|
||||
state
|
||||
.uci_mode
|
||||
.transition(UCIModeTransition::Bestmove)
|
||||
.unwrap();
|
||||
outp_bestmove(msg_bestmove);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -16,5 +16,6 @@ Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
|
||||
pub use crate::eval::{eval_metrics, EvalMetrics};
|
||||
pub use crate::fen::{FromFen, ToFen};
|
||||
pub use crate::movegen::{FromUCIAlgebraic, Move, MoveGen, ToUCIAlgebraic};
|
||||
pub use crate::search::{best_line, best_move, InterfaceMsg, SearchEval, TranspositionTable, EngineState, SearchConfig};
|
||||
pub use crate::search::{best_line, best_move, SearchEval, TranspositionTable, EngineState, SearchConfig, TimeLimits};
|
||||
pub use crate::{Board, Color, BOARD_HEIGHT, BOARD_WIDTH, N_COLORS, N_PIECES, N_SQUARES};
|
||||
pub use crate::coordination::{UCIMode, UCIModeTransition, UCIModeMachine, MsgBestmove, MsgToMain, MsgToEngine, GoMessage};
|
||||
|
263
src/search.rs
263
src/search.rs
@ -13,12 +13,14 @@ Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
|
||||
|
||||
//! Game-tree search.
|
||||
|
||||
use crate::coordination::MsgToEngine;
|
||||
use crate::eval::{Eval, EvalInt};
|
||||
use crate::hash::ZobristTable;
|
||||
use crate::movegen::{Move, MoveGen};
|
||||
use crate::{Board, Piece};
|
||||
use std::cmp::max;
|
||||
use std::cmp::{max, min};
|
||||
use std::sync::mpsc;
|
||||
use std::time::{Instant, Duration};
|
||||
|
||||
// min can't be represented as positive
|
||||
const EVAL_WORST: EvalInt = -(EvalInt::MAX);
|
||||
@ -43,6 +45,8 @@ pub enum SearchEval {
|
||||
Checkmate(i8),
|
||||
/// Centipawn score.
|
||||
Centipawns(EvalInt),
|
||||
/// Search was hard-stopped.
|
||||
Stopped,
|
||||
}
|
||||
|
||||
impl SearchEval {
|
||||
@ -58,6 +62,7 @@ impl SearchEval {
|
||||
}
|
||||
}
|
||||
SearchEval::Centipawns(eval) => Self::Centipawns(-eval),
|
||||
SearchEval::Stopped => SearchEval::Stopped,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -74,6 +79,7 @@ impl From<SearchEval> for EvalInt {
|
||||
}
|
||||
}
|
||||
SearchEval::Centipawns(eval) => eval,
|
||||
SearchEval::Stopped => panic!("Attempted to evaluate a halted search"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -96,11 +102,13 @@ impl PartialOrd for SearchEval {
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct SearchConfig {
|
||||
/// Enable alpha-beta pruning.
|
||||
alpha_beta_on: bool,
|
||||
pub alpha_beta_on: bool,
|
||||
/// Limit regular search depth
|
||||
depth: usize,
|
||||
pub depth: usize,
|
||||
/// Enable transposition table.
|
||||
enable_trans_table: bool,
|
||||
pub enable_trans_table: bool,
|
||||
/// Transposition table size (2^n where this is n)
|
||||
pub transposition_size: usize,
|
||||
}
|
||||
|
||||
impl Default for SearchConfig {
|
||||
@ -110,6 +118,7 @@ impl Default for SearchConfig {
|
||||
// try to make this even to be more conservative and avoid horizon problem
|
||||
depth: 10,
|
||||
enable_trans_table: true,
|
||||
transposition_size: 24,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -155,11 +164,35 @@ fn move_priority(board: &mut Board, mv: &Move) -> EvalInt {
|
||||
/// The best line (in reverse move order), and its corresponding absolute eval for the current player.
|
||||
fn minmax(
|
||||
board: &mut Board,
|
||||
engine_state: &mut EngineState<'_>,
|
||||
state: &mut EngineState,
|
||||
depth: usize,
|
||||
alpha: Option<EvalInt>,
|
||||
beta: Option<EvalInt>,
|
||||
) -> (Vec<Move>, SearchEval) {
|
||||
// these operations are relatively expensive, so only run them occasionally
|
||||
if state.node_count % (1 << 16) == 0 {
|
||||
// respect the hard stop if given
|
||||
match state.rx_engine.try_recv() {
|
||||
Ok(msg) => match msg {
|
||||
MsgToEngine::Go(_) => panic!("received go while thinking"),
|
||||
MsgToEngine::Stop => {
|
||||
return (Vec::new(), SearchEval::Stopped);
|
||||
}
|
||||
MsgToEngine::NewGame => panic!("received newgame while thinking"),
|
||||
},
|
||||
Err(e) => match e {
|
||||
mpsc::TryRecvError::Empty => {}
|
||||
mpsc::TryRecvError::Disconnected => panic!("thread Main stopped"),
|
||||
},
|
||||
}
|
||||
|
||||
if let Some(hard) = state.time_lims.hard {
|
||||
if Instant::now() > hard {
|
||||
return (Vec::new(), SearchEval::Stopped);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// default to worst, then gradually improve
|
||||
let mut alpha = alpha.unwrap_or(EVAL_WORST);
|
||||
// our best is their worst
|
||||
@ -179,8 +212,8 @@ fn minmax(
|
||||
.collect();
|
||||
|
||||
// get transposition table entry
|
||||
if engine_state.config.enable_trans_table {
|
||||
if let Some(entry) = &engine_state.cache[board.zobrist] {
|
||||
if state.config.enable_trans_table {
|
||||
if let Some(entry) = &state.cache[board.zobrist] {
|
||||
// the entry has a deeper knowledge than we do, so follow its best move exactly instead of
|
||||
// just prioritizing what it thinks is best
|
||||
if entry.depth >= depth {
|
||||
@ -209,8 +242,13 @@ fn minmax(
|
||||
|
||||
for (_priority, mv) in mvs {
|
||||
let anti_mv = mv.make(board);
|
||||
let (continuation, score) =
|
||||
minmax(board, engine_state, depth - 1, Some(-beta), Some(-alpha));
|
||||
let (continuation, score) = minmax(board, state, depth - 1, Some(-beta), Some(-alpha));
|
||||
|
||||
// propagate hard stops
|
||||
if matches!(score, SearchEval::Stopped) {
|
||||
return (Vec::new(), SearchEval::Stopped);
|
||||
}
|
||||
|
||||
let abs_score = score.increment();
|
||||
if abs_score > abs_best {
|
||||
abs_best = abs_score;
|
||||
@ -219,7 +257,7 @@ fn minmax(
|
||||
}
|
||||
alpha = max(alpha, abs_best.into());
|
||||
anti_mv.unmake(board);
|
||||
if alpha >= beta && engine_state.config.alpha_beta_on {
|
||||
if alpha >= beta && state.config.alpha_beta_on {
|
||||
// alpha-beta prune.
|
||||
//
|
||||
// Beta represents the best eval that the other player can get in sibling branches
|
||||
@ -232,8 +270,8 @@ fn minmax(
|
||||
|
||||
if let Some(best_move) = best_move {
|
||||
best_continuation.push(best_move);
|
||||
if engine_state.config.enable_trans_table {
|
||||
engine_state.cache[board.zobrist] = Some(TranspositionEntry {
|
||||
if state.config.enable_trans_table {
|
||||
state.cache[board.zobrist] = Some(TranspositionEntry {
|
||||
best_move,
|
||||
eval: abs_best,
|
||||
depth,
|
||||
@ -241,16 +279,10 @@ fn minmax(
|
||||
}
|
||||
}
|
||||
|
||||
state.node_count += 1;
|
||||
(best_continuation, abs_best)
|
||||
}
|
||||
|
||||
/// Messages from the interface to the search thread.
|
||||
pub enum InterfaceMsg {
|
||||
Stop,
|
||||
}
|
||||
|
||||
type InterfaceRx = mpsc::Receiver<InterfaceMsg>;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct TranspositionEntry {
|
||||
/// best move found last time
|
||||
@ -264,131 +296,122 @@ pub struct TranspositionEntry {
|
||||
pub type TranspositionTable = ZobristTable<TranspositionEntry>;
|
||||
|
||||
/// Iteratively deepen search until it is stopped.
|
||||
fn iter_deep(board: &mut Board, engine_state: &mut EngineState<'_>) -> (Vec<Move>, SearchEval) {
|
||||
// don't interrupt a depth 1 search so that there's at least a move to be played
|
||||
let (mut prev_line, mut prev_eval) = minmax(board, engine_state, 1, None, None);
|
||||
for depth in 2..=engine_state.config.depth {
|
||||
let (line, eval) = minmax(board, engine_state, depth, None, None);
|
||||
fn iter_deep(board: &mut Board, state: &mut EngineState) -> (Vec<Move>, SearchEval) {
|
||||
// always preserve two lines (1 is most recent)
|
||||
let (mut line1, mut eval1) = minmax(board, state, 1, None, None);
|
||||
let (mut line2, mut eval2) = (line1.clone(), eval1);
|
||||
|
||||
match engine_state.interface.try_recv() {
|
||||
Ok(msg) => match msg {
|
||||
InterfaceMsg::Stop => {
|
||||
if depth & 1 == 1 && (EvalInt::from(eval) - EvalInt::from(prev_eval) > 300) {
|
||||
// be skeptical if we move last and we suddenly earn a lot of
|
||||
// centipawns. this may be a sign of horizon problem
|
||||
return (prev_line, prev_eval);
|
||||
} else {
|
||||
return (line, eval);
|
||||
}
|
||||
for depth in 2..=state.config.depth {
|
||||
let (line, eval) = minmax(board, state, depth, None, None);
|
||||
|
||||
let mut have_to_ret = false;
|
||||
// depth of the line we're about to return.
|
||||
// our knock-off "quiescence" is skeptical of odd depths, so we need to know this.
|
||||
let mut ret_depth = depth;
|
||||
|
||||
if matches!(eval, SearchEval::Stopped) {
|
||||
ret_depth -= 1;
|
||||
have_to_ret = true;
|
||||
} else {
|
||||
(line2, eval2) = (line1, eval1);
|
||||
(line1, eval1) = (line, eval);
|
||||
if let Some(soft_lim) = state.time_lims.soft {
|
||||
if Instant::now() > soft_lim {
|
||||
have_to_ret = true;
|
||||
}
|
||||
},
|
||||
Err(e) => match e {
|
||||
mpsc::TryRecvError::Empty => {}
|
||||
mpsc::TryRecvError::Disconnected => panic!("interface thread stopped"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if have_to_ret {
|
||||
if ret_depth & 1 == 1 && (EvalInt::from(eval1) - EvalInt::from(eval2) > 300) {
|
||||
// be skeptical if we move last and we suddenly earn a lot of
|
||||
// centipawns. this may be a sign of horizon problem
|
||||
return (line2, eval2);
|
||||
} else {
|
||||
return (line1, eval1);
|
||||
}
|
||||
}
|
||||
(prev_line, prev_eval) = (line, eval);
|
||||
}
|
||||
(prev_line, prev_eval)
|
||||
(line1, eval1)
|
||||
}
|
||||
|
||||
/// Helper type to avoid retyping the same arguments into every function prototype
|
||||
pub struct EngineState<'a> {
|
||||
/// Configuration
|
||||
config: SearchConfig,
|
||||
/// Channel that can talk to the main thread
|
||||
interface: InterfaceRx,
|
||||
cache: &'a mut TranspositionTable,
|
||||
/// Deadlines for the engine to think of a move.
///
/// The default has no deadline at all.
#[derive(Default)]
pub struct TimeLimits {
    /// The engine must respect this time limit. It will abort if this deadline is passed.
    pub hard: Option<Instant>,
    /// Once this deadline has passed, the engine stops deepening after the depth it is
    /// currently searching instead of aborting mid-search.
    pub soft: Option<Instant>,
}
|
||||
|
||||
impl<'a> EngineState<'a> {
|
||||
impl TimeLimits {
|
||||
/// Make time limits based on wtime, btime (but color-independent).
|
||||
pub fn from_ourtime_theirtime(ourtime_ms: u64, _theirtime_ms: u64) -> TimeLimits {
|
||||
// hard timeout (max)
|
||||
let mut hard_ms = 100_000;
|
||||
// soft timeout (max)
|
||||
let mut soft_ms = 1_200;
|
||||
|
||||
let factor = if ourtime_ms > 5_000 { 10 } else { 40 };
|
||||
hard_ms = min(ourtime_ms / factor, hard_ms);
|
||||
soft_ms = min(ourtime_ms / 50, soft_ms);
|
||||
|
||||
let hard_limit = Instant::now() + Duration::from_millis(hard_ms);
|
||||
let soft_limit = Instant::now() + Duration::from_millis(soft_ms);
|
||||
|
||||
TimeLimits {
|
||||
hard: Some(hard_limit),
|
||||
soft: Some(soft_limit),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper type to avoid retyping the same arguments into every function prototype.
|
||||
///
|
||||
/// This should be owned outside the actual thinking part so that the engine can remember state
|
||||
/// between moves.
|
||||
pub struct EngineState {
|
||||
pub config: SearchConfig,
|
||||
/// Main -> Engine channel receiver
|
||||
pub rx_engine: mpsc::Receiver<MsgToEngine>,
|
||||
pub cache: TranspositionTable,
|
||||
/// Nodes traversed (i.e. number of times minmax called)
|
||||
pub node_count: usize,
|
||||
pub time_lims: TimeLimits,
|
||||
}
|
||||
|
||||
impl EngineState {
|
||||
pub fn new(
|
||||
config: SearchConfig,
|
||||
interface: InterfaceRx,
|
||||
cache: &'a mut TranspositionTable,
|
||||
interface: mpsc::Receiver<MsgToEngine>,
|
||||
cache: TranspositionTable,
|
||||
time_lims: TimeLimits,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
interface,
|
||||
rx_engine: interface,
|
||||
cache,
|
||||
node_count: 0,
|
||||
time_lims,
|
||||
}
|
||||
}
|
||||
|
||||
/// Wipe state between different games.
|
||||
///
|
||||
/// Configuration is preserved.
|
||||
pub fn wipe_state(&mut self) {
|
||||
self.cache = TranspositionTable::new(self.config.transposition_size);
|
||||
self.node_count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the best line (in reverse order) and its evaluation.
|
||||
pub fn best_line(board: &mut Board, engine_state: &mut EngineState<'_>) -> (Vec<Move>, SearchEval) {
|
||||
pub fn best_line(board: &mut Board, engine_state: &mut EngineState) -> (Vec<Move>, SearchEval) {
|
||||
let (line, eval) = iter_deep(board, engine_state);
|
||||
(line, eval)
|
||||
}
|
||||
|
||||
/// Find the best move.
|
||||
pub fn best_move(board: &mut Board, engine_state: &mut EngineState<'_>) -> Option<Move> {
|
||||
pub fn best_move(board: &mut Board, engine_state: &mut EngineState) -> Option<Move> {
|
||||
let (line, _eval) = best_line(board, engine_state);
|
||||
line.last().copied()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::fen::{FromFen, ToFen};
|
||||
use crate::movegen::ToUCIAlgebraic;
|
||||
|
||||
/// Theoretically, alpha-beta pruning should not affect the result of minmax.
|
||||
#[test]
|
||||
fn alpha_beta_same_result() {
|
||||
let test_cases = [
|
||||
"r2q1rk1/1bp1pp1p/p2p2p1/1p1P2P1/2n1P3/3Q1P2/PbPBN2P/3RKB1R b K - 5 15",
|
||||
"r1b1k2r/p1qpppbp/1p4pn/2B3N1/1PP1P3/2P5/P4PPP/RN1QR1K1 w kq - 0 14",
|
||||
];
|
||||
for fen in test_cases {
|
||||
let mut board = Board::from_fen(fen).unwrap();
|
||||
let (_tx, _rx) = mpsc::channel();
|
||||
let mut _cache = ZobristTable::new(0);
|
||||
let mut engine_state = EngineState::new(
|
||||
SearchConfig {
|
||||
alpha_beta_on: false,
|
||||
depth: 3,
|
||||
enable_trans_table: false,
|
||||
},
|
||||
_rx,
|
||||
&mut _cache,
|
||||
);
|
||||
|
||||
let mv_no_prune = best_move(
|
||||
&mut board,
|
||||
&mut engine_state,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(board.to_fen(), fen);
|
||||
|
||||
let (_tx, _rx) = mpsc::channel();
|
||||
let mut engine_state = EngineState::new(
|
||||
SearchConfig {
|
||||
alpha_beta_on: true,
|
||||
depth: 3,
|
||||
enable_trans_table: false,
|
||||
},
|
||||
_rx,
|
||||
&mut _cache,
|
||||
);
|
||||
|
||||
let mv_with_prune = best_move(
|
||||
&mut board,
|
||||
&mut engine_state,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(board.to_fen(), fen);
|
||||
|
||||
println!(
|
||||
"without ab prune got {}, otherwise {}, fen {}",
|
||||
mv_no_prune.to_uci_algebraic(),
|
||||
mv_with_prune.to_uci_algebraic(),
|
||||
fen
|
||||
);
|
||||
|
||||
assert_eq!(mv_no_prune, mv_with_prune);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user