diff --git a/src/hash.rs b/src/hash.rs index fe898ee..469066e 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -17,6 +17,8 @@ use crate::random::Pcg64Random; use crate::{ Board, CastleRights, ColPiece, Color, Square, BOARD_WIDTH, N_COLORS, N_PIECES, N_SQUARES, }; +use std::ops::Index; +use std::ops::IndexMut; const PIECE_KEYS: [[[u64; N_SQUARES]; N_PIECES]; N_COLORS] = [Pcg64Random::new(11).random_arr_2d_64(), Pcg64Random::new(22).random_arr_2d_64()]; @@ -79,6 +81,60 @@ impl Zobrist { pos.zobrist.toggle_castle(&pos.castle); pos.zobrist.toggle_turn(pos.turn); } + + /// Convert hash to an index. + fn truncate_hash(&self, size: usize) -> usize { + (self.hash & ((1 << size) - 1)) as usize + } +} + +/// Map that takes Zobrist hashes as keys. +/// +/// Heap allocated (it's a vector). +#[derive(Debug)] +pub struct ZobristTable { + data: Vec<(Zobrist, Option)>, + size: usize, +} + +impl ZobristTable { + /// Create a table with 2^n entries. + pub fn new(size: usize) -> Self { + assert!(size <= 27, "Attempted to make 2^{size} entry table; aborting to avoid excessive memory usage."); + ZobristTable { + data: vec![(Zobrist { hash: 0 }, None); 1 << size], + size, + } + } +} + +impl IndexMut for ZobristTable { + fn index_mut(&mut self, zobrist: Zobrist) -> &mut Self::Output { + let idx = zobrist.truncate_hash(self.size); + if self.data[idx].0 == zobrist { + &mut self.data[idx].1 + } else { + // miss, overwrite + self.data[idx].0 = zobrist; + self.data[idx].1 = None; + &mut self.data[idx].1 + } + } +} + +impl Index for ZobristTable { + type Output = Option; + + fn index(&self, zobrist: Zobrist) -> &Self::Output { + let idx = zobrist.truncate_hash(self.size); + let data = &self.data[idx]; + if data.0 == zobrist { + &data.1 + } else { + // miss + &None + } + } } #[cfg(test)] @@ -87,7 +143,7 @@ mod tests { use crate::fen::FromFen; use crate::movegen::{FromUCIAlgebraic, Move}; - /// Zobrist hashes of the same positions should be the same. (basic sanity test) + /// Zobrist hashes, and transposition table elements of the same positions should be the same. (basic sanity test) #[test] fn test_zobrist_equality() { let test_cases = [ @@ -114,16 +170,71 @@ mod tests { ]; for (pos1_fen, pos2_fen, mv_uci) in test_cases { eprintln!("tc: {}", pos1_fen); + let mut table = ZobristTable::::new(4); let mut pos1 = Board::from_fen(pos1_fen).unwrap(); let hash1_orig = pos1.zobrist; + assert_eq!(table[pos1.zobrist], None); + table[pos1.zobrist] = Some(100); eprintln!("refreshing board 2 '{}'", pos2_fen); let pos2 = Board::from_fen(pos2_fen).unwrap(); + table[pos2.zobrist] = Some(200); eprintln!("making mv {}", mv_uci); let mv = Move::from_uci_algebraic(mv_uci).unwrap(); let anti_mv = mv.make(&mut pos1); assert_eq!(pos1.zobrist, pos2.zobrist); + assert_eq!(table[pos1.zobrist], Some(200)); anti_mv.unmake(&mut pos1); assert_eq!(pos1.zobrist, hash1_orig); + assert_eq!(table[pos1.zobrist], Some(100)); } } + + // test that positions are equal when they loop back to the start + #[test] + fn test_zobrist_loops() { + let test_cases = [ + ( + "4k3/4r3/8/8/8/8/3R4/3K4 w - - 0 1", + "d2f2 e7f7 f2d2 f7e7", + ), + ( + "4k3/4r3/8/8/8/8/3R4/3K4 w - - 0 1", + "d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7", + ), + ]; + + for (fen, mvs_str) in test_cases { + let pos_orig = Board::from_fen(fen).unwrap(); + let mut pos = pos_orig.clone(); + for mv_str in mvs_str.split_whitespace() { + let mv = Move::from_uci_algebraic(mv_str).unwrap(); + mv.make(&mut pos); + } + pos.half_moves = pos_orig.half_moves; + pos.full_moves = pos_orig.full_moves; + assert_eq!(pos, pos_orig, "test case is incorrect, position should loop back to the original"); + assert_eq!(pos.zobrist, pos_orig.zobrist); + } + } + + #[test] + fn test_table() { + let mut table = ZobristTable::::new(4); + + macro_rules! z { + ($i: expr) => { Zobrist { hash: $i } } + } + + let big_number = 1 << 62; + + table[z!(big_number + 3)] = Some(4); + table[z!(big_number + 19)] = Some(5); + + // clobbered by newer entry + assert_eq!(table[z!(big_number + 3)], None); + + assert_eq!(table[z!(big_number + 19)], Some(5)); + + eprintln!("{table:?}"); + } } diff --git a/src/movegen.rs b/src/movegen.rs index e2869ed..26d63c0 100644 --- a/src/movegen.rs +++ b/src/movegen.rs @@ -14,7 +14,7 @@ Copyright © 2024 dogeystamp //! Move generation. use crate::fen::ToFen; -use crate::hash::Zobrist; +use crate::hash::{Zobrist, ZobristTable}; use crate::{ Board, CastleRights, ColPiece, Color, Piece, Square, SquareError, BOARD_HEIGHT, BOARD_WIDTH, N_SQUARES, @@ -770,8 +770,13 @@ impl MoveGenInternal for Board { } } -/// How many nodes at depth N can be reached from this position. -pub fn perft(depth: usize, pos: &mut Board) -> usize { +fn perft_internal(depth: usize, pos: &mut Board, cache: &mut ZobristTable<(usize, usize)>) -> usize { + if let Some((ans, cache_depth)) = cache[pos.zobrist] { + if depth == cache_depth { + return ans; + } + } + if depth == 0 { return 1; }; @@ -781,13 +786,20 @@ pub fn perft(depth: usize, pos: &mut Board) -> usize { let moves: Vec = pos.gen_moves().into_iter().collect(); for mv in moves { let anti_move = mv.make(pos); - ans += perft(depth - 1, pos); + ans += perft_internal(depth - 1, pos, cache); anti_move.unmake(pos); } + cache[pos.zobrist] = Some((ans, depth)); ans } +/// How many nodes at depth N can be reached from this position. +pub fn perft(depth: usize, pos: &mut Board) -> usize { + let mut cache = ZobristTable::new(23); + perft_internal(depth, pos, &mut cache) +} + #[cfg(test)] mod tests { use super::*;