perf: repetition draw avoidance now only looks at necessary moves

uses half-move counter.
This commit is contained in:
dogeystamp 2024-12-29 22:32:00 -05:00
parent 9995f13693
commit 612e6ffc15
No known key found for this signature in database
3 changed files with 46 additions and 18 deletions

View File

@ -19,10 +19,10 @@ use std::str::FromStr;
pub mod coordination; pub mod coordination;
pub mod eval; pub mod eval;
pub mod nnue;
pub mod fen; pub mod fen;
mod hash; mod hash;
pub mod movegen; pub mod movegen;
pub mod nnue;
pub mod random; pub mod random;
pub mod search; pub mod search;
@ -554,7 +554,7 @@ const HISTORY_SIZE: usize = 100;
impl BoardHistory { impl BoardHistory {
/// Counts occurences of this hash in the history. /// Counts occurences of this hash in the history.
fn count(&self, hash: Zobrist) -> usize { fn _count(&self, hash: Zobrist) -> usize {
let mut ans = 0; let mut ans = 0;
let mut i = self.ptr_start; let mut i = self.ptr_start;
@ -568,6 +568,20 @@ impl BoardHistory {
ans ans
} }
/// Find if there are at least `n` matches for a hash in the last `recent` plies.
fn at_least_in_recent(&self, mut n: usize, recent: usize, hash: Zobrist) -> bool {
let mut i = self.ptr_end - recent;
while i != self.ptr_end && n > 0 {
if self.hashes[usize::from(i)] == hash {
n -= 1;
}
i += 1;
}
n == 0
}
/// Add (push) hash to history. /// Add (push) hash to history.
fn push(&mut self, hash: Zobrist) { fn push(&mut self, hash: Zobrist) {
self.hashes[usize::from(self.ptr_end)] = hash; self.hashes[usize::from(self.ptr_end)] = hash;
@ -633,6 +647,12 @@ impl Board {
self.history.push(self.zobrist); self.history.push(self.zobrist);
} }
/// Is this position a draw by three repetitions?
pub fn is_repetition(&mut self) -> bool {
self.history
.at_least_in_recent(2, self.half_moves, self.zobrist)
}
/// Get iterator over all squares. /// Get iterator over all squares.
pub fn squares() -> impl Iterator<Item = Square> { pub fn squares() -> impl Iterator<Item = Square> {
(0..N_SQUARES).map(Square::try_from).map(|x| x.unwrap()) (0..N_SQUARES).map(Square::try_from).map(|x| x.unwrap())
@ -948,20 +968,30 @@ mod tests {
history.push(board.zobrist); history.push(board.zobrist);
} }
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 1); assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 1);
assert!(history.at_least_in_recent(1, 1, board.zobrist));
assert!(history.at_least_in_recent(2, 3, board.zobrist));
assert!(history.at_least_in_recent(1, 3, board.zobrist));
let board_empty = Board::default(); let board_empty = Board::default();
history.push(board_empty.zobrist); history.push(board_empty.zobrist);
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 2); assert!(!history.at_least_in_recent(1, 1, board.zobrist));
assert_eq!(history.count(board_empty.zobrist), 1); assert!(history.at_least_in_recent(1, 2, board.zobrist));
assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 2);
assert_eq!(history._count(board_empty.zobrist), 1);
assert!(history.at_least_in_recent(1, 3, board.zobrist));
assert!(history.at_least_in_recent(1, 20, board_empty.zobrist));
assert!(history.at_least_in_recent(1, 15, board_empty.zobrist));
assert!(history.at_least_in_recent(1, 1, board_empty.zobrist));
for _ in 0..3 { for _ in 0..3 {
history.push(board_empty.zobrist); history.push(board_empty.zobrist);
} }
assert_eq!(history.count(board_empty.zobrist), 4); assert_eq!(history._count(board_empty.zobrist), 4);
assert_eq!(history.count(board.zobrist), HISTORY_SIZE - 5); assert_eq!(history._count(board.zobrist), HISTORY_SIZE - 5);
} }
#[test] #[test]
@ -982,7 +1012,8 @@ mod tests {
} }
// this is the third occurence, but beforehand there are two occurences // this is the third occurence, but beforehand there are two occurences
assert_eq!(board.history.count(board.zobrist), 2); assert_eq!(board.history._count(board.zobrist), 2);
assert!(board.is_repetition(), "fen: {}", board.to_fen());
} }
/// engine should take advantage of the three time repetition rule /// engine should take advantage of the three time repetition rule
@ -1032,7 +1063,7 @@ mod tests {
expected_bestmv.make(&mut board); expected_bestmv.make(&mut board);
eprintln!( eprintln!(
"after expected mv, board repeated {} times", "after expected mv, board repeated {} times",
board.history.count(board.zobrist) board.history._count(board.zobrist)
); );
assert_eq!( assert_eq!(

View File

@ -375,15 +375,13 @@ fn task_engine(tx_main: Sender<MsgToMain>, rx_engine: Receiver<MsgToEngine>) {
let board_tensor = chess_inator::nnue::InputTensor::from_board(&board); let board_tensor = chess_inator::nnue::InputTensor::from_board(&board);
let abs_eval = EvalInt::from(eval) * EvalInt::from(board.get_turn().sign()); let abs_eval = EvalInt::from(eval) * EvalInt::from(board.get_turn().sign());
info.push(format!("NNUETrainInfo {} {} {}", is_quiet, abs_eval, {board_tensor})) info.push(format!("NNUETrainInfo {} {} {}", is_quiet, abs_eval, {
board_tensor
}))
} }
tx_main tx_main
.send(MsgToMain::Bestmove(MsgBestmove { .send(MsgToMain::Bestmove(MsgBestmove { pv, eval, info }))
pv,
eval,
info,
}))
.unwrap(); .unwrap();
} }
MsgToEngine::Stop => {} MsgToEngine::Stop => {}

View File

@ -224,7 +224,7 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
} }
} }
let is_repetition_draw = board.history.count(board.zobrist) >= 2; let is_repetition_draw = board.is_repetition();
let phase_factor = EvalInt::from(board.eval.min_maj_pieces / 5); let phase_factor = EvalInt::from(board.eval.min_maj_pieces / 5);
// positive here since we're looking from the opposite perspective. // positive here since we're looking from the opposite perspective.
// if white caused a draw, then we'd be black here. // if white caused a draw, then we'd be black here.
@ -311,7 +311,6 @@ fn minmax(board: &mut Board, state: &mut EngineState, mm: MinmaxState) -> (Vec<M
// sort moves by decreasing priority // sort moves by decreasing priority
mvs.sort_unstable_by_key(|mv| -mv.0); mvs.sort_unstable_by_key(|mv| -mv.0);
// default to worst, then gradually improve // default to worst, then gradually improve
let mut alpha = mm.alpha.unwrap_or(EVAL_WORST); let mut alpha = mm.alpha.unwrap_or(EVAL_WORST);
// our best is their worst // our best is their worst