diff --git a/src/movegen.rs b/src/movegen.rs index bc7836e..de3b0a9 100644 --- a/src/movegen.rs +++ b/src/movegen.rs @@ -17,10 +17,7 @@ pub struct Node { impl Default for Node { fn default() -> Self { - Node { - pos: BoardState::from_fen(START_POSITION).expect("Starting FEN should be valid"), - prev: None, - } + Node::new(BoardState::from_fen(START_POSITION).expect("Starting FEN should be valid")) } } @@ -36,6 +33,13 @@ impl Node { panic!("unmake should not be called on root node"); } } + + pub fn new(board: BoardState) -> Self { + Node { + pos: board, + prev: None, + } + } } /// Piece enum specifically for promotions. @@ -82,6 +86,9 @@ impl Move { fn make_unlinked(self, old_pos: BoardState) -> BoardState { let mut new_pos = old_pos; + // reset en passant + new_pos.ep_square = None; + if old_pos.turn == Color::Black { new_pos.full_moves += 1; } @@ -130,8 +137,8 @@ impl Move { let pc_dest: Option = new_pos.get_piece(self.dest); - let (src_row, src_col) = self.src.to_row_col(); - let (dest_row, dest_col) = self.dest.to_row_col(); + let (src_row, src_col) = self.src.to_row_col_signed(); + let (dest_row, dest_col) = self.dest.to_row_col_signed(); if matches!(pc_src.pc, Piece::Pawn) { // pawn moves are irreversible @@ -139,33 +146,32 @@ impl Move { // set en-passant target square if src_row.abs_diff(dest_row) == 2 { - let new_idx = match pc_src.col { - Color::White => self.src.0 + BOARD_WIDTH, - Color::Black => self.src.0 - BOARD_WIDTH, - }; - new_pos.ep_square = Some( - Square::try_from(new_idx).expect("En-passant target should be valid."), - ) - } else { - new_pos.ep_square = None; - if pc_dest.is_none() && src_col != dest_col { - // we took en passant - debug_assert!(src_row.abs_diff(dest_row) == 1); - debug_assert_eq!(self.dest, old_pos.ep_square.unwrap()); - // square to actually capture at - let ep_capture = Square::try_from(match pc_src.col { - Color::White => self.dest.0 - BOARD_WIDTH, - Color::Black => self.dest.0 + BOARD_WIDTH, - }) - .expect("En-passant capture square should be valid."); - new_pos - .del_piece(ep_capture) - .expect("En-passant capture square should have piece."); - } + let ep_col = src_col; + debug_assert_eq!(src_col, dest_col); + let ep_row = dest_row + + match pc_src.col { + Color::White => -1, + Color::Black => 1, + }; + let ep_targ = Square::from_row_col_signed(ep_row, ep_col) + .expect("En-passant target should be valid."); + new_pos.ep_square = Some(ep_targ) + } else if pc_dest.is_none() && src_col != dest_col { + // we took en passant + debug_assert!(src_row.abs_diff(dest_row) == 1); + debug_assert_eq!(self.dest, old_pos.ep_square.unwrap()); + // square to actually capture at + let ep_capture = Square::try_from(match pc_src.col { + Color::White => self.dest.0 - BOARD_WIDTH, + Color::Black => self.dest.0 + BOARD_WIDTH, + }) + .expect("En-passant capture square should be valid"); + new_pos.del_piece(ep_capture).unwrap_or_else(|_| { + panic!("En-passant capture square should have piece. Position '{}', move {:?}", old_pos.to_fen(), self) + }); } } else { new_pos.half_moves += 1; - new_pos.ep_square = None; } if pc_dest.is_some() { @@ -186,16 +192,16 @@ impl Move { let rook_src_col = if src_col > dest_col { 0 } else { - BOARD_WIDTH - 1 + isize::try_from(BOARD_WIDTH).unwrap() - 1 }; let rook_dest_col = if src_col > dest_col { dest_col + 1 } else { dest_col - 1 }; - let rook_src = Square::from_row_col(rook_row, rook_src_col) + let rook_src = Square::from_row_col_signed(rook_row, rook_src_col) .expect("rook castling src square should be valid"); - let rook_dest = Square::from_row_col(rook_row, rook_dest_col) + let rook_dest = Square::from_row_col_signed(rook_row, rook_dest_col) .expect("rook castling dest square should be valid"); debug_assert!(new_pos.get_piece(rook_src).is_some(), "rook castling src square has no rook (move: {rook_src} -> {rook_dest})"); new_pos.move_piece(rook_src, rook_dest); @@ -671,7 +677,7 @@ impl LegalMoveGen for Node { } /// How many nodes at depth N can be reached from this position. -fn perft(depth: usize, node: &Rc) -> usize { +pub fn perft(depth: usize, node: &Rc) -> usize { if depth == 0 { return 1; }; @@ -1162,10 +1168,7 @@ mod tests { expected_moves.sort_unstable(); let expected_moves = expected_moves; - let node = Node { - pos: board, - prev: None, - }; + let node = Node::new(board); let mut moves: Vec = node.gen_moves().into_iter().collect(); moves.sort_unstable(); @@ -1289,10 +1292,7 @@ mod tests { // make move eprintln!("Starting test case {i}, make move."); - let mut node = Rc::new(Node { - pos: BoardState::from_fen(start_pos).unwrap(), - prev: None, - }); + let mut node = Rc::new(Node::new(BoardState::from_fen(start_pos).unwrap())); for (move_str, expect_fen) in moves { let mv = Move::from_uci_algebraic(move_str).unwrap(); eprintln!("Moving {move_str}."); @@ -1320,39 +1320,51 @@ mod tests { // https://www.chessprogramming.org/Perft_Results let test_cases = [ ( + // fen START_POSITION, - // Only up to depth 4 because the engine isn't good enough to do this often - vec![1, 20, 400, 8_902, 197_281], + // expected perft values + vec![1, 20, 400, 8_902, 197_281, 4_865_609, 119_060_324], + // limit depth when not under `cargo test --release` (unoptimized build too slow) + 4, ), ( "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1", - vec![1, 48, 2_039, 97862], + vec![1, 48, 2_039, 97_862, 4_085_603], + 3, ), ( "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1", - vec![1, 14, 191, 2_812, 43_238], + vec![1, 14, 191, 2_812, 43_238, 674_624, 11_030_083], + 4, ), ( "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", - vec![1, 6, 264, 9467], + vec![1, 6, 264, 9467, 422_333, 15_833_292], + 3, ), ( "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", - vec![1, 44, 1_486, 62_379], + vec![1, 44, 1_486, 62_379, 2_103_487, 89_941_194], + 3, ), ( "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", - vec![1, 46, 2_079, 89_890], + vec![1, 46, 2_079, 89_890, 3_894_594], + 3, ), ]; - for (fen, expected_values) in test_cases { - let root_node = Rc::new(Node { - pos: BoardState::from_fen(fen).unwrap(), - prev: None, - }); + for (fen, expected_values, debug_limit_depth) in test_cases { + let root_node = Rc::new(Node::new(BoardState::from_fen(fen).unwrap())); for (depth, expected) in expected_values.iter().enumerate() { - assert_eq!(perft(depth, &root_node), *expected, "failed perft depth {depth} on position '{fen}'"); + eprintln!("running perft depth {depth} on position '{fen}'"); + #[cfg(debug_assertions)] + { + if depth > debug_limit_depth { + break; + } + } + assert_eq!(perft(depth, &root_node), *expected,); } } }