Compare commits

...

5 Commits

Author SHA1 Message Date
8075e19bee feat: basic transposition table 2024-11-17 14:02:47 -05:00
6b672f83dc feat: perft zobrist cache 2024-11-17 13:07:44 -05:00
36753f6ecb refactor: random
uses a rust 1.83 feature so that the const fn works better
2024-11-17 10:19:05 -05:00
71594051f5 chore: fix warnings 2024-11-16 21:56:53 -05:00
4ca186b67e chore: fmt 2024-11-16 21:55:56 -05:00
10 changed files with 249 additions and 87 deletions

View File

@ -9,6 +9,8 @@ Features:
- Tapered midgame-endgame evaluation
- UCI compatibility
- Iterative deepening
- Transposition table (Zobrist hashing)
- Currently only stores best move.
## instructions

2
rust-toolchain.toml Normal file
View File

@ -0,0 +1,2 @@
[toolchain]
channel = "beta-2024-11-17"

View File

@ -15,12 +15,12 @@ Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
use chess_inator::eval::eval_metrics;
use chess_inator::fen::FromFen;
use chess_inator::movegen::{FromUCIAlgebraic, Move, ToUCIAlgebraic};
use chess_inator::search::{best_line, InterfaceMsg, SearchEval};
use chess_inator::search::{best_line, InterfaceMsg, SearchEval, TranspositionTable};
use chess_inator::Board;
use std::io;
use std::sync::mpsc::channel;
use std::thread;
use std::time::{Duration, Instant};
use std::time::Duration;
/// UCI protocol says to ignore any unknown words.
///
@ -90,7 +90,7 @@ fn cmd_position(mut tokens: std::str::SplitWhitespace<'_>) -> Board {
}
/// Play the game.
fn cmd_go(mut _tokens: std::str::SplitWhitespace<'_>, board: &mut Board) {
fn cmd_go(mut _tokens: std::str::SplitWhitespace<'_>, board: &mut Board, cache: &mut Option<TranspositionTable>) {
// interface-to-engine
let (tx1, rx) = channel();
let tx2 = tx1.clone();
@ -101,7 +101,7 @@ fn cmd_go(mut _tokens: std::str::SplitWhitespace<'_>, board: &mut Board) {
let _ = tx2.send(InterfaceMsg::Stop);
});
let (line, eval) = best_line(board, None, Some(rx));
let (line, eval) = best_line(board, None, Some(rx), cache);
let chosen = line.last().copied();
println!(
@ -133,6 +133,7 @@ fn main() {
let stdin = io::stdin();
let mut board = Board::starting_pos();
let mut transposition_table = Some(TranspositionTable::new(24));
loop {
let mut line = String::new();
@ -156,7 +157,7 @@ fn main() {
board = cmd_position(tokens);
}
"go" => {
cmd_go(tokens, &mut board);
cmd_go(tokens, &mut board, &mut transposition_table);
}
// non-standard command.
"eval" => {

View File

@ -377,7 +377,8 @@ pub fn eval_metrics(board: &Board) -> EvalMetrics {
.map_or(0, |(k1, k2)| k1.manhattan(k2));
// attempt to minimize king distance for checkmates
let king_distance_eval = -advantage * i32::try_from(king_distance).unwrap() * max(7 - phase, 0) / 100;
let king_distance_eval =
-advantage * i32::try_from(king_distance).unwrap() * max(7 - phase, 0) / 100;
let eval = pst_eval + king_distance_eval;

View File

@ -11,8 +11,8 @@ You should have received a copy of the GNU General Public License along with che
Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
*/
use crate::{Board, ColPiece, Color, Square, SquareIdx, BOARD_HEIGHT, BOARD_WIDTH};
use crate::hash::Zobrist;
use crate::{Board, ColPiece, Color, Square, SquareIdx, BOARD_HEIGHT, BOARD_WIDTH};
pub const START_POSITION: &str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";

View File

@ -13,22 +13,24 @@ Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
//! Zobrist hash implementation.
use crate::random::{random_arr_2d_64, random_arr_64};
use crate::random::Pcg64Random;
use crate::{
Board, CastleRights, ColPiece, Color, Square, BOARD_WIDTH, N_COLORS, N_PIECES, N_SQUARES,
};
use std::ops::Index;
use std::ops::IndexMut;
const PIECE_KEYS: [[[u64; N_SQUARES]; N_PIECES]; N_COLORS] =
[random_arr_2d_64(11), random_arr_2d_64(22)];
[Pcg64Random::new(11).random_arr_2d_64(), Pcg64Random::new(22).random_arr_2d_64()];
// 4 bits in castle perms -> 16 keys
const CASTLE_KEYS: [u64; 16] = random_arr_64(33);
const CASTLE_KEYS: [u64; 16] = Pcg64Random::new(33).random_arr_64();
// ep can be specified by the file
const EP_KEYS: [u64; BOARD_WIDTH] = random_arr_64(44);
const EP_KEYS: [u64; BOARD_WIDTH] = Pcg64Random::new(44).random_arr_64();
// current turn
const COL_KEY: [u64; N_COLORS] = random_arr_64(55);
const COL_KEY: [u64; N_COLORS] = Pcg64Random::new(55).random_arr_64();
/// Zobrist hash state.
///
@ -79,6 +81,60 @@ impl Zobrist {
pos.zobrist.toggle_castle(&pos.castle);
pos.zobrist.toggle_turn(pos.turn);
}
/// Convert hash to an index.
fn truncate_hash(&self, size: usize) -> usize {
(self.hash & ((1 << size) - 1)) as usize
}
}
/// Map that takes Zobrist hashes as keys.
///
/// Heap allocated (it's a vector).
#[derive(Debug)]
pub struct ZobristTable<T> {
data: Vec<(Zobrist, Option<T>)>,
size: usize,
}
impl<T: Copy> ZobristTable<T> {
/// Create a table with 2^n entries.
pub fn new(size: usize) -> Self {
assert!(size <= 27, "Attempted to make 2^{size} entry table; aborting to avoid excessive memory usage.");
ZobristTable {
data: vec![(Zobrist { hash: 0 }, None); 1 << size],
size,
}
}
}
impl<T> IndexMut<Zobrist> for ZobristTable<T> {
fn index_mut(&mut self, zobrist: Zobrist) -> &mut Self::Output {
let idx = zobrist.truncate_hash(self.size);
if self.data[idx].0 == zobrist {
&mut self.data[idx].1
} else {
// miss, overwrite
self.data[idx].0 = zobrist;
self.data[idx].1 = None;
&mut self.data[idx].1
}
}
}
impl<T> Index<Zobrist> for ZobristTable<T> {
type Output = Option<T>;
fn index(&self, zobrist: Zobrist) -> &Self::Output {
let idx = zobrist.truncate_hash(self.size);
let data = &self.data[idx];
if data.0 == zobrist {
&data.1
} else {
// miss
&None
}
}
}
#[cfg(test)]
@ -87,7 +143,7 @@ mod tests {
use crate::fen::FromFen;
use crate::movegen::{FromUCIAlgebraic, Move};
/// Zobrist hashes of the same positions should be the same. (basic sanity test)
/// Zobrist hashes, and transposition table elements of the same positions should be the same. (basic sanity test)
#[test]
fn test_zobrist_equality() {
let test_cases = [
@ -114,16 +170,71 @@ mod tests {
];
for (pos1_fen, pos2_fen, mv_uci) in test_cases {
eprintln!("tc: {}", pos1_fen);
let mut table = ZobristTable::<usize>::new(4);
let mut pos1 = Board::from_fen(pos1_fen).unwrap();
let hash1_orig = pos1.zobrist;
assert_eq!(table[pos1.zobrist], None);
table[pos1.zobrist] = Some(100);
eprintln!("refreshing board 2 '{}'", pos2_fen);
let pos2 = Board::from_fen(pos2_fen).unwrap();
table[pos2.zobrist] = Some(200);
eprintln!("making mv {}", mv_uci);
let mv = Move::from_uci_algebraic(mv_uci).unwrap();
let anti_mv = mv.make(&mut pos1);
assert_eq!(pos1.zobrist, pos2.zobrist);
assert_eq!(table[pos1.zobrist], Some(200));
anti_mv.unmake(&mut pos1);
assert_eq!(pos1.zobrist, hash1_orig);
assert_eq!(table[pos1.zobrist], Some(100));
}
}
// test that positions are equal when they loop back to the start
#[test]
fn test_zobrist_loops() {
let test_cases = [
(
"4k3/4r3/8/8/8/8/3R4/3K4 w - - 0 1",
"d2f2 e7f7 f2d2 f7e7",
),
(
"4k3/4r3/8/8/8/8/3R4/3K4 w - - 0 1",
"d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7 d2f2 e7f7 f2d2 f7e7",
),
];
for (fen, mvs_str) in test_cases {
let pos_orig = Board::from_fen(fen).unwrap();
let mut pos = pos_orig.clone();
for mv_str in mvs_str.split_whitespace() {
let mv = Move::from_uci_algebraic(mv_str).unwrap();
mv.make(&mut pos);
}
pos.half_moves = pos_orig.half_moves;
pos.full_moves = pos_orig.full_moves;
assert_eq!(pos, pos_orig, "test case is incorrect, position should loop back to the original");
assert_eq!(pos.zobrist, pos_orig.zobrist);
}
}
#[test]
fn test_table() {
let mut table = ZobristTable::<usize>::new(4);
macro_rules! z {
($i: expr) => { Zobrist { hash: $i } }
}
let big_number = 1 << 62;
table[z!(big_number + 3)] = Some(4);
table[z!(big_number + 19)] = Some(5);
// clobbered by newer entry
assert_eq!(table[z!(big_number + 3)], None);
assert_eq!(table[z!(big_number + 19)], Some(5));
eprintln!("{table:?}");
}
}

View File

@ -21,7 +21,7 @@ pub mod eval;
pub mod fen;
mod hash;
pub mod movegen;
mod random;
pub mod random;
pub mod search;
use crate::fen::{FromFen, ToFen, START_POSITION};

View File

@ -13,8 +13,8 @@ Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
//! Move generation.
use crate::hash::Zobrist;
use crate::fen::ToFen;
use crate::hash::{Zobrist, ZobristTable};
use crate::{
Board, CastleRights, ColPiece, Color, Piece, Square, SquareError, BOARD_HEIGHT, BOARD_WIDTH,
N_SQUARES,
@ -85,7 +85,7 @@ pub struct AntiMove {
dest: Square,
src: Square,
/// Captured piece, always assumed to be of enemy color.
pub (crate) cap: Option<Piece>,
pub(crate) cap: Option<Piece>,
move_type: AntiMoveType,
/// Half-move counter prior to this move
half_moves: usize,
@ -770,8 +770,13 @@ impl MoveGenInternal for Board {
}
}
/// How many nodes at depth N can be reached from this position.
pub fn perft(depth: usize, pos: &mut Board) -> usize {
fn perft_internal(depth: usize, pos: &mut Board, cache: &mut ZobristTable<(usize, usize)>) -> usize {
if let Some((ans, cache_depth)) = cache[pos.zobrist] {
if depth == cache_depth {
return ans;
}
}
if depth == 0 {
return 1;
};
@ -781,13 +786,20 @@ pub fn perft(depth: usize, pos: &mut Board) -> usize {
let moves: Vec<Move> = pos.gen_moves().into_iter().collect();
for mv in moves {
let anti_move = mv.make(pos);
ans += perft(depth - 1, pos);
ans += perft_internal(depth - 1, pos, cache);
anti_move.unmake(pos);
}
cache[pos.zobrist] = Some((ans, depth));
ans
}
/// How many nodes at depth N can be reached from this position.
pub fn perft(depth: usize, pos: &mut Board) -> usize {
let mut cache = ZobristTable::new(23);
perft_internal(depth, pos, &mut cache)
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,77 +1,82 @@
//! Rust port by dogeystamp <dogeystamp@disroot.org> of
//! the pcg64 dxsm random number generator (https://dotat.at/@/2023-06-21-pcg64-dxsm.html)
struct Pcg64Random {
pub struct Pcg64Random {
state: u128,
inc: u128,
}
/// Generates an array of random numbers.
///
/// The `rng` parameter only sets the initial state. This function is deterministic and pure.
///
/// # Returns
///
/// The array of random numbers, plus the RNG state at the end.
const fn pcg64_dxsm<const N: usize>(mut rng: Pcg64Random) -> ([u64; N], Pcg64Random) {
let mut ret = [0; N];
const MUL: u64 = 15750249268501108917;
let mut i = 0;
while i < N {
let state: u128 = rng.state;
rng.state = state.wrapping_mul(MUL as u128).wrapping_add(rng.inc);
let mut hi: u64 = (state >> 64) as u64;
let lo: u64 = (state | 1) as u64;
hi ^= hi >> 32;
hi &= MUL;
hi ^= hi >> 48;
hi = hi.wrapping_mul(lo);
ret[i] = hi;
i += 1;
}
(ret, rng)
}
/// Make an RNG state "sane".
const fn pcg64_seed(mut rng: Pcg64Random) -> Pcg64Random {
// ensure rng.inc is odd
rng.inc = (rng.inc << 1) | 1;
rng.state += rng.inc;
// one iteration of random
let (_, rng) = pcg64_dxsm::<1>(rng);
rng.rand();
rng
}
/// Generate array of random numbers, based on a seed.
///
/// This function is pure and deterministic, and also works at compile-time rather than at runtime.
///
/// Example (generate 10 random numbers):
///
///```rust
/// use crate::random::random_arr_64;
/// const ARR: [u64; 10] = random_arr_64(123456);
///```
pub const fn random_arr_64<const N: usize>(seed: u128) -> [u64; N] {
let rng = pcg64_seed(Pcg64Random {
// chosen by fair dice roll
state: 24437033748623976104561743679864923857,
inc: seed,
});
pcg64_dxsm(rng).0
}
/// Generate 2D array of random numbers based on a seed.
pub const fn random_arr_2d_64<const N: usize, const M: usize>(seed: u128) -> [[u64; N]; M] {
let mut ret = [[0; N]; M];
let mut i = 0;
while i < M {
ret[i] = random_arr_64(seed);
i += 1;
impl Pcg64Random {
pub const fn new(seed: u128) -> Self {
pcg64_seed(Pcg64Random {
// chosen by fair dice roll
state: 24437033748623976104561743679864923857,
inc: seed,
})
}
/// Returns a single random number.
pub const fn rand(&mut self) -> u64 {
const MUL: u64 = 15750249268501108917;
let state: u128 = self.state;
self.state = state.wrapping_mul(MUL as u128).wrapping_add(self.inc);
let mut hi: u64 = (state >> 64) as u64;
let lo: u64 = (state | 1) as u64;
hi ^= hi >> 32;
hi &= MUL;
hi ^= hi >> 48;
hi = hi.wrapping_mul(lo);
hi
}
/// Generate array of random numbers, based on a seed.
///
/// # Returns
///
/// A tuple with the random number array, and the RNG state afterwards so you can reuse it in later
/// calls (otherwise you'll get the same result if you're using the same seed.)
///
/// # Example
///
///```rust
/// use chess_inator::random::Pcg64Random;
///
/// // generate 3 random numbers
/// const ARR: [u64; 3] = Pcg64Random::new(123456).random_arr_64();
/// assert_eq!(ARR, [4526545874411451611, 1124465636717751929, 12699417402402334336])
///```
pub const fn random_arr_64<const N: usize>(&mut self) -> [u64; N] {
let mut ret = [0; N];
let mut i = 0;
while i < N {
let num = self.rand();
ret[i] = num;
i += 1;
}
ret
}
/// Generate 2D array of random numbers based on a seed.
pub const fn random_arr_2d_64<const N: usize, const M: usize>(&mut self) -> [[u64; N]; M] {
let mut ret = [[0; N]; M];
let mut i = 0;
while i < M {
ret[i] = self.random_arr_64();
i += 1;
}
ret
}
ret
}

View File

@ -14,7 +14,8 @@ Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
//! Game-tree search.
use crate::eval::{Eval, EvalInt};
use crate::movegen::{Move, MoveGen, ToUCIAlgebraic};
use crate::hash::ZobristTable;
use crate::movegen::{Move, MoveGen};
use crate::{Board, Piece};
use std::cmp::max;
use std::sync::mpsc;
@ -154,6 +155,7 @@ fn minmax(
depth: usize,
alpha: Option<EvalInt>,
beta: Option<EvalInt>,
cache: &mut TranspositionTableOpt,
) -> (Vec<Move>, SearchEval) {
// default to worst, then gradually improve
let mut alpha = alpha.unwrap_or(EVAL_WORST);
@ -165,7 +167,6 @@ fn minmax(
return (Vec::new(), SearchEval::Centipawns(eval));
}
// sort moves by decreasing priority
let mut mvs: Vec<_> = board
.gen_moves()
.into_iter()
@ -173,6 +174,15 @@ fn minmax(
.into_iter()
.map(|mv| (move_priority(board, &mv), mv))
.collect();
// remember the prior best move
if let Some(cache) = cache {
if let Some(entry) = &cache[board.zobrist] {
mvs.push((EVAL_BEST, entry.best_move));
}
}
// sort moves by decreasing priority
mvs.sort_unstable_by_key(|mv| -mv.0);
let mut abs_best = SearchEval::Centipawns(EVAL_WORST);
@ -190,7 +200,8 @@ fn minmax(
for (_priority, mv) in mvs {
let anti_mv = mv.make(board);
let (continuation, score) = minmax(board, config, depth - 1, Some(-beta), Some(-alpha));
let (continuation, score) =
minmax(board, config, depth - 1, Some(-beta), Some(-alpha), cache);
let abs_score = score.increment();
if abs_score > abs_best {
abs_best = abs_score;
@ -210,8 +221,11 @@ fn minmax(
}
}
if let Some(mv) = best_move {
best_continuation.push(mv);
if let Some(best_move) = best_move {
best_continuation.push(best_move);
if let Some(cache) = cache {
cache[board.zobrist] = Some(TranspositionEntry { best_move });
}
}
(best_continuation, abs_best)
@ -224,14 +238,24 @@ pub enum InterfaceMsg {
type InterfaceRx = mpsc::Receiver<InterfaceMsg>;
#[derive(Clone, Copy, Debug)]
pub struct TranspositionEntry {
// best move found last time
best_move: Move,
}
pub type TranspositionTable = ZobristTable<TranspositionEntry>;
type TranspositionTableOpt = Option<TranspositionTable>;
/// Iteratively deepen search until it is stopped.
fn iter_deep(
board: &mut Board,
config: &SearchConfig,
interface: Option<InterfaceRx>,
cache: &mut TranspositionTableOpt,
) -> (Vec<Move>, SearchEval) {
for depth in 1..=config.depth {
let (line, eval) = minmax(board, config, depth, None, None);
let (line, eval) = minmax(board, config, depth, None, None, cache);
if let Some(ref rx) = interface {
match rx.try_recv() {
Ok(msg) => match msg {
@ -243,7 +267,7 @@ fn iter_deep(
},
}
} else if depth == config.depth - 1 {
return (line, eval)
return (line, eval);
}
}
panic!("iterative deepening did not search at all")
@ -254,9 +278,10 @@ pub fn best_line(
board: &mut Board,
config: Option<SearchConfig>,
interface: Option<InterfaceRx>,
cache: &mut TranspositionTableOpt,
) -> (Vec<Move>, SearchEval) {
let config = config.unwrap_or_default();
let (line, eval) = iter_deep(board, &config, interface);
let (line, eval) = iter_deep(board, &config, interface, cache);
(line, eval)
}
@ -265,8 +290,9 @@ pub fn best_move(
board: &mut Board,
config: Option<SearchConfig>,
interface: Option<InterfaceRx>,
cache: &mut TranspositionTableOpt,
) -> Option<Move> {
let (line, _eval) = best_line(board, Some(config.unwrap_or_default()), interface);
let (line, _eval) = best_line(board, Some(config.unwrap_or_default()), interface, cache);
line.last().copied()
}
@ -292,6 +318,7 @@ mod tests {
depth: 3,
}),
None,
&mut None,
)
.unwrap();
@ -304,6 +331,7 @@ mod tests {
depth: 3,
}),
None,
&mut None,
)
.unwrap();