fix: en passant bug

also added more depth on some perft positions
This commit is contained in:
dogeystamp 2024-10-22 17:11:04 -04:00
parent 61e5a6114b
commit 09551e802e

View File

@ -17,10 +17,7 @@ pub struct Node {
impl Default for Node { impl Default for Node {
fn default() -> Self { fn default() -> Self {
Node { Node::new(BoardState::from_fen(START_POSITION).expect("Starting FEN should be valid"))
pos: BoardState::from_fen(START_POSITION).expect("Starting FEN should be valid"),
prev: None,
}
} }
} }
@ -36,6 +33,13 @@ impl Node {
panic!("unmake should not be called on root node"); panic!("unmake should not be called on root node");
} }
} }
pub fn new(board: BoardState) -> Self {
Node {
pos: board,
prev: None,
}
}
} }
/// Piece enum specifically for promotions. /// Piece enum specifically for promotions.
@ -82,6 +86,9 @@ impl Move {
fn make_unlinked(self, old_pos: BoardState) -> BoardState { fn make_unlinked(self, old_pos: BoardState) -> BoardState {
let mut new_pos = old_pos; let mut new_pos = old_pos;
// reset en passant
new_pos.ep_square = None;
if old_pos.turn == Color::Black { if old_pos.turn == Color::Black {
new_pos.full_moves += 1; new_pos.full_moves += 1;
} }
@ -130,8 +137,8 @@ impl Move {
let pc_dest: Option<ColPiece> = new_pos.get_piece(self.dest); let pc_dest: Option<ColPiece> = new_pos.get_piece(self.dest);
let (src_row, src_col) = self.src.to_row_col(); let (src_row, src_col) = self.src.to_row_col_signed();
let (dest_row, dest_col) = self.dest.to_row_col(); let (dest_row, dest_col) = self.dest.to_row_col_signed();
if matches!(pc_src.pc, Piece::Pawn) { if matches!(pc_src.pc, Piece::Pawn) {
// pawn moves are irreversible // pawn moves are irreversible
@ -139,33 +146,32 @@ impl Move {
// set en-passant target square // set en-passant target square
if src_row.abs_diff(dest_row) == 2 { if src_row.abs_diff(dest_row) == 2 {
let new_idx = match pc_src.col { let ep_col = src_col;
Color::White => self.src.0 + BOARD_WIDTH, debug_assert_eq!(src_col, dest_col);
Color::Black => self.src.0 - BOARD_WIDTH, let ep_row = dest_row
}; + match pc_src.col {
new_pos.ep_square = Some( Color::White => -1,
Square::try_from(new_idx).expect("En-passant target should be valid."), Color::Black => 1,
) };
} else { let ep_targ = Square::from_row_col_signed(ep_row, ep_col)
new_pos.ep_square = None; .expect("En-passant target should be valid.");
if pc_dest.is_none() && src_col != dest_col { new_pos.ep_square = Some(ep_targ)
// we took en passant } else if pc_dest.is_none() && src_col != dest_col {
debug_assert!(src_row.abs_diff(dest_row) == 1); // we took en passant
debug_assert_eq!(self.dest, old_pos.ep_square.unwrap()); debug_assert!(src_row.abs_diff(dest_row) == 1);
// square to actually capture at debug_assert_eq!(self.dest, old_pos.ep_square.unwrap());
let ep_capture = Square::try_from(match pc_src.col { // square to actually capture at
Color::White => self.dest.0 - BOARD_WIDTH, let ep_capture = Square::try_from(match pc_src.col {
Color::Black => self.dest.0 + BOARD_WIDTH, Color::White => self.dest.0 - BOARD_WIDTH,
}) Color::Black => self.dest.0 + BOARD_WIDTH,
.expect("En-passant capture square should be valid."); })
new_pos .expect("En-passant capture square should be valid");
.del_piece(ep_capture) new_pos.del_piece(ep_capture).unwrap_or_else(|_| {
.expect("En-passant capture square should have piece."); panic!("En-passant capture square should have piece. Position '{}', move {:?}", old_pos.to_fen(), self)
} });
} }
} else { } else {
new_pos.half_moves += 1; new_pos.half_moves += 1;
new_pos.ep_square = None;
} }
if pc_dest.is_some() { if pc_dest.is_some() {
@ -186,16 +192,16 @@ impl Move {
let rook_src_col = if src_col > dest_col { let rook_src_col = if src_col > dest_col {
0 0
} else { } else {
BOARD_WIDTH - 1 isize::try_from(BOARD_WIDTH).unwrap() - 1
}; };
let rook_dest_col = if src_col > dest_col { let rook_dest_col = if src_col > dest_col {
dest_col + 1 dest_col + 1
} else { } else {
dest_col - 1 dest_col - 1
}; };
let rook_src = Square::from_row_col(rook_row, rook_src_col) let rook_src = Square::from_row_col_signed(rook_row, rook_src_col)
.expect("rook castling src square should be valid"); .expect("rook castling src square should be valid");
let rook_dest = Square::from_row_col(rook_row, rook_dest_col) let rook_dest = Square::from_row_col_signed(rook_row, rook_dest_col)
.expect("rook castling dest square should be valid"); .expect("rook castling dest square should be valid");
debug_assert!(new_pos.get_piece(rook_src).is_some(), "rook castling src square has no rook (move: {rook_src} -> {rook_dest})"); debug_assert!(new_pos.get_piece(rook_src).is_some(), "rook castling src square has no rook (move: {rook_src} -> {rook_dest})");
new_pos.move_piece(rook_src, rook_dest); new_pos.move_piece(rook_src, rook_dest);
@ -671,7 +677,7 @@ impl LegalMoveGen for Node {
} }
/// How many nodes at depth N can be reached from this position. /// How many nodes at depth N can be reached from this position.
fn perft(depth: usize, node: &Rc<Node>) -> usize { pub fn perft(depth: usize, node: &Rc<Node>) -> usize {
if depth == 0 { if depth == 0 {
return 1; return 1;
}; };
@ -1162,10 +1168,7 @@ mod tests {
expected_moves.sort_unstable(); expected_moves.sort_unstable();
let expected_moves = expected_moves; let expected_moves = expected_moves;
let node = Node { let node = Node::new(board);
pos: board,
prev: None,
};
let mut moves: Vec<Move> = node.gen_moves().into_iter().collect(); let mut moves: Vec<Move> = node.gen_moves().into_iter().collect();
moves.sort_unstable(); moves.sort_unstable();
@ -1289,10 +1292,7 @@ mod tests {
// make move // make move
eprintln!("Starting test case {i}, make move."); eprintln!("Starting test case {i}, make move.");
let mut node = Rc::new(Node { let mut node = Rc::new(Node::new(BoardState::from_fen(start_pos).unwrap()));
pos: BoardState::from_fen(start_pos).unwrap(),
prev: None,
});
for (move_str, expect_fen) in moves { for (move_str, expect_fen) in moves {
let mv = Move::from_uci_algebraic(move_str).unwrap(); let mv = Move::from_uci_algebraic(move_str).unwrap();
eprintln!("Moving {move_str}."); eprintln!("Moving {move_str}.");
@ -1320,39 +1320,51 @@ mod tests {
// https://www.chessprogramming.org/Perft_Results // https://www.chessprogramming.org/Perft_Results
let test_cases = [ let test_cases = [
( (
// fen
START_POSITION, START_POSITION,
// Only up to depth 4 because the engine isn't good enough to do this often // expected perft values
vec![1, 20, 400, 8_902, 197_281], vec![1, 20, 400, 8_902, 197_281, 4_865_609, 119_060_324],
// limit depth when not under `cargo test --release` (unoptimized build too slow)
4,
), ),
( (
"r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1", "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1",
vec![1, 48, 2_039, 97862], vec![1, 48, 2_039, 97_862, 4_085_603],
3,
), ),
( (
"8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1", "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1",
vec![1, 14, 191, 2_812, 43_238], vec![1, 14, 191, 2_812, 43_238, 674_624, 11_030_083],
4,
), ),
( (
"r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1",
vec![1, 6, 264, 9467], vec![1, 6, 264, 9467, 422_333, 15_833_292],
3,
), ),
( (
"rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8",
vec![1, 44, 1_486, 62_379], vec![1, 44, 1_486, 62_379, 2_103_487, 89_941_194],
3,
), ),
( (
"r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10",
vec![1, 46, 2_079, 89_890], vec![1, 46, 2_079, 89_890, 3_894_594],
3,
), ),
]; ];
for (fen, expected_values) in test_cases { for (fen, expected_values, debug_limit_depth) in test_cases {
let root_node = Rc::new(Node { let root_node = Rc::new(Node::new(BoardState::from_fen(fen).unwrap()));
pos: BoardState::from_fen(fen).unwrap(),
prev: None,
});
for (depth, expected) in expected_values.iter().enumerate() { for (depth, expected) in expected_values.iter().enumerate() {
assert_eq!(perft(depth, &root_node), *expected, "failed perft depth {depth} on position '{fen}'"); eprintln!("running perft depth {depth} on position '{fen}'");
#[cfg(debug_assertions)]
{
if depth > debug_limit_depth {
break;
}
}
assert_eq!(perft(depth, &root_node), *expected,);
} }
} }
} }