feat: make move forfeits castling rights

This commit is contained in:
dogeystamp 2024-10-01 21:13:06 -04:00
parent ca0c17cbbe
commit 8804c0e1c4

View File

@ -1,7 +1,7 @@
//! Move generation. //! Move generation.
use crate::fen::{FromFen, ToFen, START_POSITION}; use crate::fen::{FromFen, ToFen, START_POSITION};
use crate::{BoardState, Color, Square, BOARD_WIDTH}; use crate::{BoardState, Color, Piece, Square, BOARD_HEIGHT, BOARD_WIDTH, N_SQUARES};
use std::rc::Rc; use std::rc::Rc;
/// Game tree node. /// Game tree node.
@ -71,7 +71,7 @@ impl Move {
Move::Castle(data) => todo!(), Move::Castle(data) => todo!(),
Move::Normal(data) => { Move::Normal(data) => {
let pc_src = node.pos.get_piece(data.src).unwrap(); let pc_src = node.pos.get_piece(data.src).unwrap();
if matches!(pc_src.pc, crate::Piece::Pawn) { if matches!(pc_src.pc, Piece::Pawn) {
// pawn moves are irreversible // pawn moves are irreversible
node.pos.half_moves = 0; node.pos.half_moves = 0;
@ -94,6 +94,30 @@ impl Move {
node.pos.ep_square = None; node.pos.ep_square = None;
} }
let castle = &mut node.pos.castle.0[pc_src.col as usize];
// forfeit castling rights
if matches!(pc_src.pc, Piece::King) {
castle.k = false;
castle.q = false;
} else if matches!(pc_src.pc, Piece::Rook) {
match pc_src.col {
Color::White => {
if data.src == Square(0) {
castle.q = false;
} else if data.src == Square(BOARD_WIDTH - 1) {
castle.k = false;
};
}
Color::Black => {
if data.src == Square((BOARD_HEIGHT - 1) * BOARD_WIDTH) {
castle.q = false;
} else if data.src == Square(N_SQUARES - 1) {
castle.k = false;
};
}
}
}
if let Some(_pc_dest) = node.pos.get_piece(data.dest) { if let Some(_pc_dest) = node.pos.get_piece(data.dest) {
// captures are irreversible // captures are irreversible
node.pos.half_moves = 0; node.pos.half_moves = 0;
@ -114,89 +138,119 @@ mod tests {
use super::*; use super::*;
use crate::fen::START_POSITION; use crate::fen::START_POSITION;
/// Test that make move and unmake move for simple piece pushes/captures. /// Test that make move and unmake move work as expected.
/// ///
/// Also tests en passant target square. /// Ensure that:
/// - En passant target is appropriately set
/// - Castling rights are respected
/// - Half-moves since last irreversible move counter is maintained
#[test] #[test]
fn test_normal_move() { fn test_make_unmake() {
let start_pos = START_POSITION;
// (src, dest, expected fen)
// FENs made with https://lichess.org/analysis // FENs made with https://lichess.org/analysis
// En-passant target square is manually added, since Lichess doesn't have it when // En-passant target square is manually added, since Lichess doesn't have it when
// en-passant is not legal. // en-passant is not legal.
let moves = [ let test_cases = [
( (
"e2", START_POSITION,
"e4", vec![
"rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1", // (src, dest, expected fen)
(
"e2",
"e4",
"rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
),
(
"e7",
"e5",
"rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq e6 0 2",
),
(
"g1",
"f3",
"rnbqkbnr/pppp1ppp/8/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 2",
),
(
"g8",
"f6",
"rnbqkb1r/pppp1ppp/5n2/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3",
),
(
"f1",
"c4",
"rnbqkb1r/pppp1ppp/5n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3",
),
(
"f8",
"c5",
"rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4",
),
(
"d1",
"e2",
"rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R b KQkq - 5 4",
),
(
"d8",
"e7",
"rnb1k2r/ppppqppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R w KQkq - 6 5",
),
(
"f3",
"e5",
"rnb1k2r/ppppqppp/5n2/2b1N3/2B1P3/8/PPPPQPPP/RNB1K2R b KQkq - 0 5",
),
(
"e7",
"e5",
"rnb1k2r/pppp1ppp/5n2/2b1q3/2B1P3/8/PPPPQPPP/RNB1K2R w KQkq - 0 6",
),
],
), ),
// castling rights test (kings)
( (
"e7", "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2",
"e5", vec![
"rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq e6 0 2", (
), "e1",
( "e2",
"g1", "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPPKPPP/RNBQ1BNR b kq - 1 2",
"f3", ),
"rnbqkbnr/pppp1ppp/8/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 2", (
), "e8",
( "e7",
"g8", "rnbq1bnr/ppppkppp/8/4p3/4P3/8/PPPPKPPP/RNBQ1BNR w - - 2 3",
"f6", ),
"rnbqkb1r/pppp1ppp/5n2/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3", ],
),
(
"f1",
"c4",
"rnbqkb1r/pppp1ppp/5n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3",
),
(
"f8",
"c5",
"rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4",
),
(
"d1",
"e2",
"rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R b KQkq - 5 4",
),
(
"d8",
"e7",
"rnb1k2r/ppppqppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R w KQkq - 6 5",
),
(
"f3",
"e5",
"rnb1k2r/ppppqppp/5n2/2b1N3/2B1P3/8/PPPPQPPP/RNB1K2R b KQkq - 0 5",
),
(
"e7",
"e5",
"rnb1k2r/pppp1ppp/5n2/2b1q3/2B1P3/8/PPPPQPPP/RNB1K2R w KQkq - 0 6",
), ),
]; ];
// make move for (i, test_case) in test_cases.iter().enumerate() {
let mut node = Node::default(); let (start_pos, moves) = test_case;
for (src, dest, expect_fen) in moves {
let idx_src = Square::from_algebraic(src.to_string()).unwrap();
let idx_dest = Square::from_algebraic(dest.to_string()).unwrap();
let mv = Move::Normal(MoveData {
src: idx_src,
dest: idx_dest,
});
node = mv.make(node);
assert_eq!(node.pos.to_fen(), expect_fen.to_string())
}
// unmake move // make move
let mut cur_node = Rc::new(node.clone()); println!("Starting test case {i}, make move.");
for (_, _, expect_fen) in moves.iter().rev().chain([("", "", START_POSITION)].iter()) { let mut node = Node {pos: BoardState::from_fen(start_pos.to_string()).unwrap(), prev: None};
println!("{}", expect_fen); for (src, dest, expect_fen) in moves {
assert_eq!(*cur_node.pos.to_fen(), expect_fen.to_string()); println!("Moving {src} to {dest}.");
if *expect_fen != START_POSITION { let idx_src = Square::from_algebraic(src.to_string()).unwrap();
cur_node = cur_node.prev.clone().unwrap(); let idx_dest = Square::from_algebraic(dest.to_string()).unwrap();
let mv = Move::Normal(MoveData {
src: idx_src,
dest: idx_dest,
});
node = mv.make(node);
assert_eq!(node.pos.to_fen(), expect_fen.to_string())
}
// unmake move
println!("Starting test case {i}, unmake move.");
let mut cur_node = Rc::new(node.clone());
for (_, _, expect_fen) in moves.iter().rev().chain([("", "", *start_pos)].iter()) {
println!("{}", expect_fen);
assert_eq!(*cur_node.pos.to_fen(), expect_fen.to_string());
if *expect_fen != *start_pos {
cur_node = cur_node.prev.clone().unwrap();
}
} }
} }
} }