From 16141e851ca74027cd49d9fa33e925fdd85446fa Mon Sep 17 00:00:00 2001 From: dogeystamp Date: Tue, 22 Oct 2024 12:51:22 -0400 Subject: [PATCH] feat: ensure king does not castle through check --- src/fen.rs | 6 +-- src/lib.rs | 6 +-- src/movegen.rs | 144 +++++++++++++++++++++++++++++++++++++------------ 3 files changed, 117 insertions(+), 39 deletions(-) diff --git a/src/fen.rs b/src/fen.rs index b47ef09..865eebc 100644 --- a/src/fen.rs +++ b/src/fen.rs @@ -109,15 +109,15 @@ impl FromFen for BoardState { pc_char @ ('a'..='z' | 'A'..='Z') => { let pc = ColPiece::try_from(pc_char).or(bad_char!(i, c))?; + if col > 7 { + return Err(FenError::TooManyPieces(i)); + }; pos.set_piece( Square::from_row_col(real_row, col) .or(Err(FenError::InternalError(i)))?, pc, ); col += 1; - if col > 8 { - return Err(FenError::TooManyPieces(i)); - }; parser_state = FenState::Piece(row, col) } number @ '1'..='9' => { diff --git a/src/lib.rs b/src/lib.rs index d933d2f..7ffa3b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,8 +6,6 @@ use std::str::FromStr; pub mod fen; pub mod movegen; -use fen::FromFen; - const BOARD_WIDTH: usize = 8; const BOARD_HEIGHT: usize = 8; const N_SQUARES: usize = BOARD_WIDTH * BOARD_HEIGHT; @@ -485,7 +483,7 @@ impl BoardState { } fn move_piece(&mut self, src: Square, dest: Square) { - let pc = self.del_piece(src).expect("Move source should have piece."); + let pc = self.del_piece(src).unwrap_or_else(|_| panic!("move ({src} -> {dest}) should have piece at source")); self.set_piece(dest, pc); } @@ -543,6 +541,8 @@ impl core::fmt::Display for BoardState { mod tests { use super::*; + use fen::FromFen; + #[test] fn test_square_casts() { let fail_cases = [-1, 64, 0x7FFFFFFF, 257, 256, 128, 65, -3, !0x7FFFFFFF]; diff --git a/src/movegen.rs b/src/movegen.rs index 2a0624d..145e4dd 100644 --- a/src/movegen.rs +++ b/src/movegen.rs @@ -188,6 +188,7 @@ impl Move { .expect("rook castling src square should be valid"); let rook_dest = Square::from_row_col(rook_row, rook_dest_col) .expect("rook castling dest square should be valid"); + debug_assert!(new_pos.get_piece(rook_src).is_some(), "rook castling src square has no rook (move: {rook_src} -> {rook_dest})"); new_pos.move_piece(rook_src, rook_dest); } debug_assert!( @@ -306,8 +307,9 @@ impl FromUCIAlgebraic for Move { /// Pseudo-legal move generation. /// -/// "Pseudo-legal" here means that moving into check is allowed, and capturing friendly pieces is -/// allowed. These will be filtered out in the legal move generation step. +/// "Pseudo-legal" here means that moving into check (but not castling through check) is allowed, +/// and capturing friendly pieces is allowed. These will be filtered out in the legal move +/// generation step. pub trait PseudoMoveGen { fn gen_pseudo_moves(&self) -> impl IntoIterator; } @@ -416,41 +418,69 @@ impl PseudoMoveGen for BoardState { } for src in squares!(King) { move_slider(self, src, &mut ret, SliderDirection::Star, false); - let (r, c) = src.to_row_col(); + let (r, c) = src.to_row_col_signed(); let rights = self.pl_castle(self.turn); - let castle_sides = [(rights.k, 2, BOARD_WIDTH - 1), (rights.q, -2, 0)]; + let castle_sides = [(rights.k, 2, BOARD_WIDTH as isize - 1), (rights.q, -2, 0)]; for (is_allowed, move_offset, endpoint) in castle_sides { if !is_allowed { continue; } - let range = if c < endpoint { + let (rook_r, rook_c) = (r, endpoint); + let rook_sq = Square::from_row_col_signed(rook_r, rook_c).unwrap(); + let rook_exists = self + .get_piece(rook_sq) + .map_or(false, |pc| pc.pc == Piece::Rook); + if !rook_exists { + continue; + } + + let path_range = if c < endpoint { (c + 1)..endpoint } else { (endpoint + 1)..c }; debug_assert_ne!( - range.len(), + path_range.len(), 0, "c {:?}, endpoint {:?}, range {:?}", c, endpoint, - range + path_range ); - let mut range_squares = range.map(|nc| Square::from_row_col(r, nc).unwrap()); - debug_assert_ne!(range_squares.len(), 0); + let mut path_squares = + path_range.map(|nc| Square::from_row_col_signed(r, nc).unwrap()); + debug_assert_ne!(path_squares.len(), 0); // find first blocking piece - let is_path_blocked = range_squares.find_map(|sq| self.get_piece(sq)).is_some(); + let is_path_blocked = path_squares.find_map(|sq| self.get_piece(sq)).is_some(); if is_path_blocked { continue; } - let nc: isize = c.try_into().unwrap(); - let dest = Square::from_row_col_signed(r.try_into().unwrap(), nc + move_offset) + let nc: isize = c + move_offset; + let dest = Square::from_row_col_signed(r, nc) .expect("Castle destination square should be valid"); + + debug_assert!(c.abs_diff(nc) == 2); + + // ensure the path is not being attacked (castle through check) + let check_range = if c < nc { c..=nc } else { nc..=c }; + debug_assert!(!check_range.is_empty()); + let is_any_checked = check_range + .map(|nc| Square::from_row_col_signed(r, nc).unwrap()) + .map(|dest| { + let mut board = *self; + board.move_piece(src, dest); + is_check(&board, self.turn) + }) + .any(|x| x); + if is_any_checked { + continue; + } + ret.push(Move { src, dest, @@ -711,6 +741,21 @@ mod tests { ), ], ), + // white castle test (blocked again) + ( + "8/8/8/8/8/8/r6r/R3K1nR w KQ - 0 1", + vec![ + // NOTE: pseudo-legal e1 + ("a1", vec!["b1", "a2", "c1", "d1", "e1"], MoveType::Normal), + ("h1", vec!["g1", "h2"], MoveType::Normal), + // NOTE: pseudo-legal d2, e2, f2, f1 + ( + "e1", + vec!["c1", "d1", "f1", "d2", "e2", "f2"], + MoveType::Normal, + ), + ], + ), // white castle test (no rights, blocked) ( "8/8/8/8/8/8/r6r/R3Kn1R w K - 0 1", @@ -997,6 +1042,52 @@ mod tests { ), ], ), + // castling through check + ( + "8/8/8/8/8/8/6rr/4K2R w KQ - 0 1", + vec![ + ("e1", vec!["d1", "f1"], MoveType::Normal), + ("h1", vec!["g1", "f1", "h2"], MoveType::Normal), + ], + ), + // castling through check + ( + "8/8/8/8/8/8/5r1r/4K2R w KQ - 0 1", + vec![ + ("e1", vec!["d1"], MoveType::Normal), + ("h1", vec!["g1", "f1", "h2"], MoveType::Normal), + ], + ), + // castling while checked + ( + "8/8/8/8/8/8/rrrrr2r/4K2R w KQ - 0 1", + vec![ + ("e1", vec!["f1"], MoveType::Normal), + ], + ), + // castling while checked + ( + "8/8/8/8/8/8/r3rrrr/R3K3 w KQ - 0 1", + vec![ + ("e1", vec!["d1"], MoveType::Normal), + ], + ), + // castling through check + ( + "8/8/8/8/8/8/r1r5/R3K3 w KQ - 0 1", + vec![ + ("e1", vec!["d1", "f1"], MoveType::Normal), + ("a1", vec!["a2", "b1", "c1", "d1"], MoveType::Normal), + ], + ), + // castling through check + ( + "8/8/8/8/8/8/r2r4/R3K3 w KQ - 0 1", + vec![ + ("e1", vec!["f1"], MoveType::Normal), + ("a1", vec!["a2", "b1", "c1", "d1"], MoveType::Normal), + ], + ), // check test ( "1bqnb1q1/3q1n2/q2p4/2bKp1q1/1r1b4/1q3n2/8/k2q2b1 w - - 0 1", @@ -1009,9 +1100,14 @@ mod tests { ), ]; - for (i, (fen, expected)) in test_cases.iter().enumerate() { + for tc in test_cases { + eprintln!("on test {}", tc.0); + + let (board, mut expected_moves) = decondense_moves(tc); + expected_moves.sort_unstable(); + let node = Node { - pos: BoardState::from_fen(fen).unwrap(), + pos: board, prev: None, }; @@ -1019,27 +1115,9 @@ mod tests { moves.sort_unstable(); let moves = moves; - let mut expected_moves = expected - .iter() - .map(|(src, dests, move_type)| { - let src = src.parse::().unwrap(); - let dests = dests - .iter() - .map(|x| x.parse::()) - .map(|x| x.unwrap()); - dests.map(move |dest| Move { - src, - dest, - move_type: *move_type, - }) - }) - .flatten() - .collect::>(); - - expected_moves.sort_unstable(); let expected_moves = expected_moves; - assert_eq!(moves, expected_moves, "failed test case {i} ({fen})"); + assert_eq!(moves, expected_moves); } }