diff --git a/src/movegen.rs b/src/movegen.rs index 9b1d063..612bb34 100644 --- a/src/movegen.rs +++ b/src/movegen.rs @@ -582,6 +582,24 @@ impl GenAttackers for Board { } } +/// Options for movegen. +pub struct MoveGenConfig { + /// Restricts movegen to only output capture moves. + /// + /// This is more efficient than filtering captures after generating moves. + captures_only: bool, + legality: MoveGenType, +} + +impl Default for MoveGenConfig { + fn default() -> Self { + MoveGenConfig { + captures_only: false, + legality: MoveGenType::Legal, + } + } +} + #[derive(Debug, Clone, Copy)] enum MoveGenType { /// Legal move generation. @@ -590,19 +608,41 @@ enum MoveGenType { _Pseudo, } -/// Internal, slightly more general movegen interface +/// Internal movegen interface with more options trait MoveGenInternal { - fn gen_moves_general(&mut self, gen_type: MoveGenType) -> impl IntoIterator; + fn gen_moves_general(&mut self, config: MoveGenConfig) -> impl IntoIterator; } pub trait MoveGen { /// Legal move generation. fn gen_moves(&mut self) -> impl IntoIterator; + + /// Pseudo-legal move generation (see `MoveGenType::_Pseudo` for more information). + fn gen_pseudo(&mut self) -> impl IntoIterator; + + /// Legal capture generation. + fn gen_captures(&mut self) -> impl IntoIterator; } impl MoveGen for T { fn gen_moves(&mut self) -> impl IntoIterator { - self.gen_moves_general(MoveGenType::Legal) + self.gen_moves_general(MoveGenConfig::default()) + } + + fn gen_pseudo(&mut self) -> impl IntoIterator { + let config = MoveGenConfig { + legality: MoveGenType::_Pseudo, + ..Default::default() + }; + self.gen_moves_general(config) + } + + fn gen_captures(&mut self) -> impl IntoIterator { + let config = MoveGenConfig { + captures_only: true, + ..Default::default() + }; + self.gen_moves_general(config) } } @@ -651,6 +691,7 @@ fn move_slider( move_list: &mut Vec, slide_type: SliderDirection, keep_going: bool, + config: &MoveGenConfig, ) { let dirs = match slide_type { SliderDirection::Straight => DIRS_STRAIGHT.iter(), @@ -669,14 +710,23 @@ fn move_slider( r = nr; c = nc; - move_list.push(Move { - src, - dest, - move_type: MoveType::Normal, - }); + let obstructed = board.get_piece(dest).is_some(); - // stop at other pieces. - if let Some(_cap_pc) = board.get_piece(dest) { + let mut gen_move = true; + + if config.captures_only && !obstructed { + gen_move = false; + } + + if gen_move { + move_list.push(Move { + src, + dest, + move_type: MoveType::Normal, + }); + } + + if obstructed { break; } } else { @@ -714,7 +764,7 @@ fn is_legal(board: &mut Board, mv: Move) -> bool { } impl MoveGenInternal for Board { - fn gen_moves_general(&mut self, gen_type: MoveGenType) -> impl IntoIterator { + fn gen_moves_general(&mut self, config: MoveGenConfig) -> impl IntoIterator { let mut ret = Vec::new(); let pl = self[self.turn]; macro_rules! squares { @@ -724,16 +774,22 @@ impl MoveGenInternal for Board { } for sq in squares!(Rook) { - move_slider(self, sq, &mut ret, SliderDirection::Straight, true); + move_slider(self, sq, &mut ret, SliderDirection::Straight, true, &config); } for sq in squares!(Bishop) { - move_slider(self, sq, &mut ret, SliderDirection::Diagonal, true); + move_slider(self, sq, &mut ret, SliderDirection::Diagonal, true, &config); } for sq in squares!(Queen) { - move_slider(self, sq, &mut ret, SliderDirection::Star, true); + move_slider(self, sq, &mut ret, SliderDirection::Star, true, &config); } for src in squares!(King) { - move_slider(self, src, &mut ret, SliderDirection::Star, false); + move_slider(self, src, &mut ret, SliderDirection::Star, false, &config); + + if config.captures_only { + // castling can't capture + continue; + } + let (r, c) = src.to_row_col_signed(); let rights = self.castle[self.turn]; let castle_sides = [(rights.k, 2, BOARD_WIDTH as isize - 1), (rights.q, -2, 0)]; @@ -845,6 +901,10 @@ impl MoveGenInternal for Board { } } + if config.captures_only { + continue; + } + // single push let nc = c; let dest = match Square::from_row_col_signed(nr, nc) { @@ -881,6 +941,9 @@ impl MoveGenInternal for Board { let nr = r + dir.0; let nc = c + dir.1; if let Ok(dest) = Square::from_row_col_signed(nr, nc) { + if config.captures_only && self.get_piece(dest).is_none() { + continue; + } ret.push(Move { src, dest, @@ -889,7 +952,7 @@ impl MoveGenInternal for Board { } } } - ret.retain(move |mv| match gen_type { + ret.retain(move |mv| match config.legality { MoveGenType::Legal => is_legal(self, *mv), MoveGenType::_Pseudo => true, }); @@ -1243,10 +1306,7 @@ mod tests { let all_cases = [augmented_test_cases, test_cases].concat(); for (mut board, expected_moves) in all_cases { - let mut moves: Vec = board - .gen_moves_general(MoveGenType::_Pseudo) - .into_iter() - .collect(); + let mut moves: Vec = board.gen_pseudo().into_iter().collect(); moves.sort_unstable(); let moves = moves; @@ -1604,4 +1664,42 @@ mod tests { assert_eq!(attackers, expected); } } + + #[test] + fn test_capture_movegen() { + let test_cases = [( + // fen + "8/3q4/5N2/8/8/8/8/3K4 w - - 0 1", + // expected moves generated + "f6d7", + ), + ( + "8/8/8/3pP3/2K5/8/8/8 w - d6 0 1", + // holy hell + "e5d6 c4d5", + ), + ( + "8/2q5/3K4/8/8/8/8/8 w - - 0 1", + "d6c7", + ), + ( + "2Q5/3r2R1/2B1PN2/8/3K4/8/8/8 w - - 0 1", + "c6d7 e6d7 c8d7 f6d7 g7d7", + ), + ]; + + for (fen, expected) in test_cases { + let mut board = Board::from_fen(fen).unwrap(); + let mut moves = board.gen_captures().into_iter().collect::>(); + moves.sort(); + let mut expected = expected + .split_whitespace() + .map(Move::from_uci_algebraic) + .map(|x| x.unwrap()) + .collect::>(); + expected.sort(); + + assert_eq!(moves, expected, "failed '{}'", fen); + } + } }