diff --git a/src/movegen.rs b/src/movegen.rs index 1a2b788..65ae136 100644 --- a/src/movegen.rs +++ b/src/movegen.rs @@ -8,7 +8,7 @@ use std::rc::Rc; /// Game tree node. #[derive(Clone, Debug)] -struct Node { +pub struct Node { /// Immutable position data. pos: BoardState, /// Backlink to previous node. @@ -24,6 +24,18 @@ impl Default for Node { } } +impl Node { + /// Undo move. + /// + /// Intended usage is to always keep an Rc to the current node, and overwrite it with the + /// result of unmake. + pub fn unmake(&self) -> Rc { + self.prev + .clone() + .expect("unmake should not be called on root node") + } +} + /// Piece enum specifically for promotions. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] enum PromotePiece { @@ -66,10 +78,10 @@ impl Move { /// /// Old position is saved in a backlink. /// No checking is done to verify even pseudo-legality of the move. - pub fn make(self, old_node: Node) -> Node { + pub fn make(self, old_node: Rc) -> Node { let old_pos = old_node.pos; let mut node = Node { - prev: Some(Rc::new(old_node)), + prev: Some(old_node), pos: old_pos, }; @@ -540,6 +552,75 @@ pub trait LegalMoveGen { fn gen_moves(&self) -> impl IntoIterator; } +/// Is a given player in check? +fn is_check(board: &BoardState, pl: Color) -> bool { + for src in board.pl(pl).board(Piece::King).into_iter() { + macro_rules! detect_checker { + ($dirs: ident, $pc: ident, $keep_going: expr) => { + for dir in $dirs.into_iter() { + let (mut r, mut c) = src.to_row_col_signed(); + loop { + let (nr, nc) = (r + dir.0, c + dir.1); + if let Ok(sq) = Square::from_row_col_signed(nr, nc) { + if let Some(pc) = board.get_piece(sq) { + if pc.pc == Piece::$pc && pc.col != pl { + return true; + } else { + break; + } + } + } else { + break; + } + if (!($keep_going)) { + break; + } + r = nr; + c = nc; + } + } + }; + } + + let dirs_white_pawn = [(-1, 1), (-1, -1)]; + let dirs_black_pawn = [(1, 1), (1, -1)]; + + detect_checker!(DIRS_STAR, Queen, true); + detect_checker!(DIRS_DIAG, Bishop, true); + detect_checker!(DIRS_STRAIGHT, Rook, true); + detect_checker!(DIRS_STAR, King, false); + detect_checker!(DIRS_KNIGHT, Knight, false); + match pl { + Color::White => detect_checker!(dirs_black_pawn, Pawn, false), + Color::Black => detect_checker!(dirs_white_pawn, Pawn, false), + } + } + false +} + +impl LegalMoveGen for Node { + fn gen_moves(&self) -> impl IntoIterator { + self.pos + .gen_pseudo_moves() + .into_iter() + .filter(|mv| { + // disallow friendly fire + let src_pc = self + .pos + .get_piece(mv.src) + .expect("move source should have piece"); + if let Some(dest_pc) = self.pos.get_piece(mv.dest) { + return dest_pc.col != src_pc.col; + } + true + }) + .filter(|mv| { + // disallow moving into check + let new_node = mv.make(self.clone().into()); + !is_check(&new_node.pos, self.pos.turn) + }) + } +} #[cfg(test)] mod tests { use super::*; @@ -801,6 +882,116 @@ mod tests { } } + /// Test check checker. + #[test] + fn test_is_check() { + let check_cases = [ + "3r4/8/8/3K4/8/8/8/8 b - - 0 1", + "8/8/8/3K3r/8/8/8/8 b - - 0 1", + "8/8/8/3K4/8/8/8/3r4 b - - 0 1", + "8/8/8/r2K4/8/8/8/8 b - - 0 1", + "1b6/8/8/3K4/1r6/8/8/k6b b - - 0 1", + "1b6/8/4p3/3K4/1r6/8/8/k5b1 b - - 0 1", + "1b6/4n3/3p4/3K4/1r6/8/8/k5b1 b - - 0 1", + "1b6/2n5/3p4/3K4/1r6/8/8/k5b1 b - - 0 1", + "1b6/8/3p4/3K4/1r3n2/8/8/k5b1 b - - 0 1", + "1b6/8/3p4/3K4/1r1k4/5n2/8/6b1 b - - 0 1", + "8/8/8/4b3/r2b4/rnq5/PP6/KRrr4 w - - 0 1", + ] + .map(|tc| (tc, true)); + + let not_check_cases = [ + "1b6/8/3p4/3K4/1r6/5n2/8/k5b1 b - - 0 1", + "1bqnb3/3q1n2/3p4/3K4/1r6/2q1qn2/8/k5b1 b - - 0 1", + "1bqnb1q1/3q1n2/3p4/1qbKp1q1/1r1b4/2q1qn2/8/k5b1 b - - 0 1", + "8/8/8/4b3/r2b4/r1q5/PP6/KRrr4 w - - 0 1", + ] + .map(|tc| (tc, false)); + + let all_cases = check_cases.iter().chain(¬_check_cases); + for (fen, expected) in all_cases { + let board = BoardState::from_fen(fen).unwrap(); + assert_eq!( + is_check(&board, Color::White), + *expected, + "failed on {}", + fen + ); + } + } + + /// Test legal movegen through contrived positions. + #[test] + fn test_legal_movegen() { + let test_cases = [ + // rook friendly fire test + ( + // start position + "8/8/8/8/8/8/rr6/RRr5 w - - 0 1", + // expected moves + vec![ + ( + // source piece + "a1", + // destination squares + vec!["a2"], + MoveType::Normal, + ), + ( + // source piece + "b1", + // destination squares + vec!["b2", "c1"], + MoveType::Normal, + ), + ], + ), + // check test + ( + "1bqnb1q1/3q1n2/q2p4/2bKp1q1/1r1b4/1q3n2/8/k2q2b1 w - - 0 1", + vec![("d5", vec!["e4"], MoveType::Normal)], + ), + // check test + ( + "1b1nb1q1/q2q1n2/q2p4/2bKp1q1/1r1p4/1q3n2/8/k2q2b1 w - - 0 1", + vec![("d5", vec!["e4"], MoveType::Normal)], + ), + ]; + + for (i, (fen, expected)) in test_cases.iter().enumerate() { + let node = Node { + pos: BoardState::from_fen(fen).unwrap(), + prev: None, + }; + + let mut moves: Vec = node.gen_moves().into_iter().collect(); + moves.sort_unstable(); + let moves = moves; + + let mut expected_moves = expected + .iter() + .map(|(src, dests, move_type)| { + let src = src.parse::().unwrap(); + let dests = dests + .iter() + .map(|x| x.parse::()) + .map(|x| x.unwrap()); + dests.map(move |dest| Move { + src, + dest, + move_type: *move_type, + }) + }) + .flatten() + .collect::>(); + + expected_moves.sort_unstable(); + let expected_moves = expected_moves; + + assert_eq!(moves, expected_moves, "failed test case {i} ({fen})"); + } + } + /// Test that make move and unmake move work as expected. /// /// Ensure that: @@ -915,25 +1106,24 @@ mod tests { // make move eprintln!("Starting test case {i}, make move."); - let mut node = Node { + let mut node = Rc::new(Node { pos: BoardState::from_fen(start_pos).unwrap(), prev: None, - }; + }); for (move_str, expect_fen) in moves { let mv = Move::from_uci_algebraic(move_str).unwrap(); eprintln!("Moving {move_str}."); - node = mv.make(node); + node = mv.make(node).into(); assert_eq!(node.pos.to_fen(), expect_fen.to_string()) } // unmake move eprintln!("Starting test case {i}, unmake move."); - let mut cur_node = Rc::new(node.clone()); for (_, expect_fen) in moves.iter().rev().chain([("", *start_pos)].iter()) { eprintln!("{}", expect_fen); - assert_eq!(*cur_node.pos.to_fen(), expect_fen.to_string()); + assert_eq!(*node.pos.to_fen(), expect_fen.to_string()); if *expect_fen != *start_pos { - cur_node = cur_node.prev.clone().unwrap(); + node = node.unmake(); } } }