fix: en passant bug

also added more depth on some perft positions
This commit is contained in:
dogeystamp 2024-10-22 17:11:04 -04:00
parent 61e5a6114b
commit 09551e802e

View File

@ -17,10 +17,7 @@ pub struct Node {
impl Default for Node {
fn default() -> Self {
Node {
pos: BoardState::from_fen(START_POSITION).expect("Starting FEN should be valid"),
prev: None,
}
Node::new(BoardState::from_fen(START_POSITION).expect("Starting FEN should be valid"))
}
}
@ -36,6 +33,13 @@ impl Node {
panic!("unmake should not be called on root node");
}
}
pub fn new(board: BoardState) -> Self {
Node {
pos: board,
prev: None,
}
}
}
/// Piece enum specifically for promotions.
@ -82,6 +86,9 @@ impl Move {
fn make_unlinked(self, old_pos: BoardState) -> BoardState {
let mut new_pos = old_pos;
// reset en passant
new_pos.ep_square = None;
if old_pos.turn == Color::Black {
new_pos.full_moves += 1;
}
@ -130,8 +137,8 @@ impl Move {
let pc_dest: Option<ColPiece> = new_pos.get_piece(self.dest);
let (src_row, src_col) = self.src.to_row_col();
let (dest_row, dest_col) = self.dest.to_row_col();
let (src_row, src_col) = self.src.to_row_col_signed();
let (dest_row, dest_col) = self.dest.to_row_col_signed();
if matches!(pc_src.pc, Piece::Pawn) {
// pawn moves are irreversible
@ -139,33 +146,32 @@ impl Move {
// set en-passant target square
if src_row.abs_diff(dest_row) == 2 {
let new_idx = match pc_src.col {
Color::White => self.src.0 + BOARD_WIDTH,
Color::Black => self.src.0 - BOARD_WIDTH,
};
new_pos.ep_square = Some(
Square::try_from(new_idx).expect("En-passant target should be valid."),
)
} else {
new_pos.ep_square = None;
if pc_dest.is_none() && src_col != dest_col {
// we took en passant
debug_assert!(src_row.abs_diff(dest_row) == 1);
debug_assert_eq!(self.dest, old_pos.ep_square.unwrap());
// square to actually capture at
let ep_capture = Square::try_from(match pc_src.col {
Color::White => self.dest.0 - BOARD_WIDTH,
Color::Black => self.dest.0 + BOARD_WIDTH,
})
.expect("En-passant capture square should be valid.");
new_pos
.del_piece(ep_capture)
.expect("En-passant capture square should have piece.");
}
let ep_col = src_col;
debug_assert_eq!(src_col, dest_col);
let ep_row = dest_row
+ match pc_src.col {
Color::White => -1,
Color::Black => 1,
};
let ep_targ = Square::from_row_col_signed(ep_row, ep_col)
.expect("En-passant target should be valid.");
new_pos.ep_square = Some(ep_targ)
} else if pc_dest.is_none() && src_col != dest_col {
// we took en passant
debug_assert!(src_row.abs_diff(dest_row) == 1);
debug_assert_eq!(self.dest, old_pos.ep_square.unwrap());
// square to actually capture at
let ep_capture = Square::try_from(match pc_src.col {
Color::White => self.dest.0 - BOARD_WIDTH,
Color::Black => self.dest.0 + BOARD_WIDTH,
})
.expect("En-passant capture square should be valid");
new_pos.del_piece(ep_capture).unwrap_or_else(|_| {
panic!("En-passant capture square should have piece. Position '{}', move {:?}", old_pos.to_fen(), self)
});
}
} else {
new_pos.half_moves += 1;
new_pos.ep_square = None;
}
if pc_dest.is_some() {
@ -186,16 +192,16 @@ impl Move {
let rook_src_col = if src_col > dest_col {
0
} else {
BOARD_WIDTH - 1
isize::try_from(BOARD_WIDTH).unwrap() - 1
};
let rook_dest_col = if src_col > dest_col {
dest_col + 1
} else {
dest_col - 1
};
let rook_src = Square::from_row_col(rook_row, rook_src_col)
let rook_src = Square::from_row_col_signed(rook_row, rook_src_col)
.expect("rook castling src square should be valid");
let rook_dest = Square::from_row_col(rook_row, rook_dest_col)
let rook_dest = Square::from_row_col_signed(rook_row, rook_dest_col)
.expect("rook castling dest square should be valid");
debug_assert!(new_pos.get_piece(rook_src).is_some(), "rook castling src square has no rook (move: {rook_src} -> {rook_dest})");
new_pos.move_piece(rook_src, rook_dest);
@ -671,7 +677,7 @@ impl LegalMoveGen for Node {
}
/// How many nodes at depth N can be reached from this position.
fn perft(depth: usize, node: &Rc<Node>) -> usize {
pub fn perft(depth: usize, node: &Rc<Node>) -> usize {
if depth == 0 {
return 1;
};
@ -1162,10 +1168,7 @@ mod tests {
expected_moves.sort_unstable();
let expected_moves = expected_moves;
let node = Node {
pos: board,
prev: None,
};
let node = Node::new(board);
let mut moves: Vec<Move> = node.gen_moves().into_iter().collect();
moves.sort_unstable();
@ -1289,10 +1292,7 @@ mod tests {
// make move
eprintln!("Starting test case {i}, make move.");
let mut node = Rc::new(Node {
pos: BoardState::from_fen(start_pos).unwrap(),
prev: None,
});
let mut node = Rc::new(Node::new(BoardState::from_fen(start_pos).unwrap()));
for (move_str, expect_fen) in moves {
let mv = Move::from_uci_algebraic(move_str).unwrap();
eprintln!("Moving {move_str}.");
@ -1320,39 +1320,51 @@ mod tests {
// https://www.chessprogramming.org/Perft_Results
let test_cases = [
(
// fen
START_POSITION,
// Only up to depth 4 because the engine isn't good enough to do this often
vec![1, 20, 400, 8_902, 197_281],
// expected perft values
vec![1, 20, 400, 8_902, 197_281, 4_865_609, 119_060_324],
// limit depth when not under `cargo test --release` (unoptimized build too slow)
4,
),
(
"r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1",
vec![1, 48, 2_039, 97862],
vec![1, 48, 2_039, 97_862, 4_085_603],
3,
),
(
"8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1",
vec![1, 14, 191, 2_812, 43_238],
vec![1, 14, 191, 2_812, 43_238, 674_624, 11_030_083],
4,
),
(
"r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1",
vec![1, 6, 264, 9467],
vec![1, 6, 264, 9467, 422_333, 15_833_292],
3,
),
(
"rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8",
vec![1, 44, 1_486, 62_379],
vec![1, 44, 1_486, 62_379, 2_103_487, 89_941_194],
3,
),
(
"r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10",
vec![1, 46, 2_079, 89_890],
vec![1, 46, 2_079, 89_890, 3_894_594],
3,
),
];
for (fen, expected_values) in test_cases {
let root_node = Rc::new(Node {
pos: BoardState::from_fen(fen).unwrap(),
prev: None,
});
for (fen, expected_values, debug_limit_depth) in test_cases {
let root_node = Rc::new(Node::new(BoardState::from_fen(fen).unwrap()));
for (depth, expected) in expected_values.iter().enumerate() {
assert_eq!(perft(depth, &root_node), *expected, "failed perft depth {depth} on position '{fen}'");
eprintln!("running perft depth {depth} on position '{fen}'");
#[cfg(debug_assertions)]
{
if depth > debug_limit_depth {
break;
}
}
assert_eq!(perft(depth, &root_node), *expected,);
}
}
}