feat: construct move from uci algebraic

This commit is contained in:
dogeystamp 2024-10-04 22:25:16 -04:00
parent 98b4f116d6
commit f73d34ee41
2 changed files with 123 additions and 76 deletions

View File

@ -435,7 +435,7 @@ mod tests {
"8/8/2k5/5q2/5n2/8/5K2/8 b - - 0 1", "8/8/2k5/5q2/5n2/8/5K2/8 b - - 0 1",
]; ];
for (i, fen1) in test_cases.iter().enumerate() { for fen1 in test_cases {
println!("fen1: {fen1:?}"); println!("fen1: {fen1:?}");
let fen2 = BoardState::from_fen(fen1.to_string()).unwrap().to_fen(); let fen2 = BoardState::from_fen(fen1.to_string()).unwrap().to_fen();

View File

@ -1,7 +1,9 @@
//! Move generation. //! Move generation.
use crate::fen::{FromFen, ToFen, START_POSITION}; use crate::fen::{FromFen, START_POSITION};
use crate::{BoardState, ColPiece, Color, Piece, Square, BOARD_HEIGHT, BOARD_WIDTH, N_SQUARES}; use crate::{
BoardState, ColPiece, Color, Piece, Square, SquareError, BOARD_HEIGHT, BOARD_WIDTH, N_SQUARES,
};
use std::rc::Rc; use std::rc::Rc;
/// Game tree node. /// Game tree node.
@ -43,23 +45,20 @@ impl From<PromotePiece> for Piece {
} }
} }
/// Move data common to all move types.
struct MoveData {
src: Square,
dest: Square,
}
/// Pseudo-legal move. /// Pseudo-legal move.
/// ///
/// No checking is made to see if the move is actually pseudo-legal. /// No checking is done when constructing this.
enum Move { enum MoveType {
/// Pawn promotes to another piece. /// Pawn promotes to another piece.
Promotion(MoveData, PromotePiece), Promotion(PromotePiece),
/// King castles with rook. /// Capture, or push move. Includes castling and en-passant too.
Castle(MoveData), Normal,
/// Capture, or push move. }
Normal(MoveData), /// Move data common to all move types.
/// This move is an en-passant capture. struct Move {
EnPassant(MoveData), src: Square,
dest: Square,
move_type: MoveType,
} }
impl Move { impl Move {
@ -93,38 +92,37 @@ impl Move {
}; };
} }
match self { match self.move_type {
Move::Promotion(data, to_piece) => { MoveType::Promotion(to_piece) => {
let pc_src = pc_src!(data); let pc_src = pc_src!(self);
pc_asserts!(pc_src, data); pc_asserts!(pc_src, self);
debug_assert_eq!(pc_src.pc, Piece::Pawn); debug_assert_eq!(pc_src.pc, Piece::Pawn);
node.pos.del_piece(data.src); node.pos.del_piece(self.src);
node.pos.set_piece( node.pos.set_piece(
data.dest, self.dest,
ColPiece { ColPiece {
pc: Piece::from(to_piece), pc: Piece::from(to_piece),
col: pc_src.col, col: pc_src.col,
}, },
); );
} }
Move::Castle(data) => todo!(), MoveType::Normal => {
Move::Normal(data) => { let pc_src = pc_src!(self);
let pc_src = pc_src!(data); pc_asserts!(pc_src, self);
pc_asserts!(pc_src, data);
if matches!(pc_src.pc, Piece::Pawn) { if matches!(pc_src.pc, Piece::Pawn) {
// pawn moves are irreversible // pawn moves are irreversible
node.pos.half_moves = 0; node.pos.half_moves = 0;
// en-passant // en-passant
if data.src.0 + (BOARD_WIDTH) * 2 == data.dest.0 { if self.src.0 + (BOARD_WIDTH) * 2 == self.dest.0 {
node.pos.ep_square = Some( node.pos.ep_square = Some(
Square::try_from(data.src.0 + BOARD_WIDTH) Square::try_from(self.src.0 + BOARD_WIDTH)
.expect("En-passant target should be valid."), .expect("En-passant target should be valid."),
) )
} else if data.dest.0 + (BOARD_WIDTH) * 2 == data.src.0 { } else if self.dest.0 + (BOARD_WIDTH) * 2 == self.src.0 {
node.pos.ep_square = Some( node.pos.ep_square = Some(
Square::try_from(data.src.0 - BOARD_WIDTH) Square::try_from(self.src.0 - BOARD_WIDTH)
.expect("En-passant target should be valid."), .expect("En-passant target should be valid."),
) )
} else { } else {
@ -143,31 +141,30 @@ impl Move {
} else if matches!(pc_src.pc, Piece::Rook) { } else if matches!(pc_src.pc, Piece::Rook) {
match pc_src.col { match pc_src.col {
Color::White => { Color::White => {
if data.src == Square(0) { if self.src == Square(0) {
castle.q = false; castle.q = false;
} else if data.src == Square(BOARD_WIDTH - 1) { } else if self.src == Square(BOARD_WIDTH - 1) {
castle.k = false; castle.k = false;
}; };
} }
Color::Black => { Color::Black => {
if data.src == Square((BOARD_HEIGHT - 1) * BOARD_WIDTH) { if self.src == Square((BOARD_HEIGHT - 1) * BOARD_WIDTH) {
castle.q = false; castle.q = false;
} else if data.src == Square(N_SQUARES - 1) { } else if self.src == Square(N_SQUARES - 1) {
castle.k = false; castle.k = false;
}; };
} }
} }
} }
if let Some(_pc_dest) = node.pos.get_piece(data.dest) { if let Some(_pc_dest) = node.pos.get_piece(self.dest) {
// captures are irreversible // captures are irreversible
node.pos.half_moves = 0; node.pos.half_moves = 0;
} }
node.pos.del_piece(data.src); node.pos.del_piece(self.src);
node.pos.set_piece(data.dest, pc_src); node.pos.set_piece(self.dest, pc_src);
} }
Move::EnPassant(data) => todo!(),
} }
node.pos.turn = node.pos.turn.flip(); node.pos.turn = node.pos.turn.flip();
@ -176,10 +173,77 @@ impl Move {
} }
} }
/// Convert from UCI long algebraic move notation.
pub trait FromUCIAlgebraic {
type Error;
fn from_uci_algebraic(value: &str) -> Result<Self, Self::Error>
where
Self: std::marker::Sized;
}
/// Convert to UCI long algebraic move notation.
pub trait ToUCIAlgebraic {
fn to_uci_algebraic(&self) -> String;
}
#[derive(Debug)]
enum MoveAlgebraicError {
/// String is invalid length; refuse to parse
InvalidLength(usize),
/// Invalid character at given index.
InvalidCharacter(usize),
/// Could not parse square string at a certain index.
SquareError(usize, SquareError),
}
impl FromUCIAlgebraic for Move {
type Error = MoveAlgebraicError;
fn from_uci_algebraic(value: &str) -> Result<Self, Self::Error> {
let value_len = value.len();
if !(4..=5).contains(&value_len) {
return Err(MoveAlgebraicError::InvalidLength(value_len));
}
let src_sq = match Square::from_algebraic(&value[0..=1]) {
Ok(sq) => sq,
Err(e) => {
return Err(MoveAlgebraicError::SquareError(0, e));
}
};
let dest_sq = match Square::from_algebraic(&value[2..=3]) {
Ok(sq) => sq,
Err(e) => {
return Err(MoveAlgebraicError::SquareError(0, e));
}
};
let mut move_type = MoveType::Normal;
if value_len == 5 {
let promote_char = value.as_bytes()[4] as char;
match promote_char {
'q' => move_type = MoveType::Promotion(PromotePiece::Queen),
'b' => move_type = MoveType::Promotion(PromotePiece::Bishop),
'n' => move_type = MoveType::Promotion(PromotePiece::Knight),
'r' => move_type = MoveType::Promotion(PromotePiece::Rook),
_ => return Err(MoveAlgebraicError::InvalidCharacter(4)),
}
}
Ok(Move {
src: src_sq,
dest: dest_sq,
move_type,
})
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::fen::START_POSITION; use crate::fen::{START_POSITION, ToFen};
/// Test that make move and unmake move work as expected. /// Test that make move and unmake move work as expected.
/// ///
@ -198,53 +262,43 @@ mod tests {
vec![ vec![
// (src, dest, expected fen) // (src, dest, expected fen)
( (
"e2", "e2e4",
"e4",
"rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1", "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
), ),
( (
"e7", "e7e5",
"e5",
"rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq e6 0 2", "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq e6 0 2",
), ),
( (
"g1", "g1f3",
"f3",
"rnbqkbnr/pppp1ppp/8/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 2", "rnbqkbnr/pppp1ppp/8/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 2",
), ),
( (
"g8", "g8f6",
"f6",
"rnbqkb1r/pppp1ppp/5n2/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3", "rnbqkb1r/pppp1ppp/5n2/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3",
), ),
( (
"f1", "f1c4",
"c4",
"rnbqkb1r/pppp1ppp/5n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3", "rnbqkb1r/pppp1ppp/5n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3",
), ),
( (
"f8", "f8c5",
"c5",
"rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4", "rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4",
), ),
( (
"d1", "d1e2",
"e2",
"rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R b KQkq - 5 4", "rnbqk2r/pppp1ppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R b KQkq - 5 4",
), ),
( (
"d8", "d8e7",
"e7",
"rnb1k2r/ppppqppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R w KQkq - 6 5", "rnb1k2r/ppppqppp/5n2/2b1p3/2B1P3/5N2/PPPPQPPP/RNB1K2R w KQkq - 6 5",
), ),
( (
"f3", "f3e5",
"e5",
"rnb1k2r/ppppqppp/5n2/2b1N3/2B1P3/8/PPPPQPPP/RNB1K2R b KQkq - 0 5", "rnb1k2r/ppppqppp/5n2/2b1N3/2B1P3/8/PPPPQPPP/RNB1K2R b KQkq - 0 5",
), ),
( (
"e7", "e7e5",
"e5",
"rnb1k2r/pppp1ppp/5n2/2b1q3/2B1P3/8/PPPPQPPP/RNB1K2R w KQkq - 0 6", "rnb1k2r/pppp1ppp/5n2/2b1q3/2B1P3/8/PPPPQPPP/RNB1K2R w KQkq - 0 6",
), ),
], ],
@ -254,13 +308,11 @@ mod tests {
"rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2", "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2",
vec![ vec![
( (
"e1", "e1e2",
"e2",
"rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPPKPPP/RNBQ1BNR b kq - 1 2", "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPPKPPP/RNBQ1BNR b kq - 1 2",
), ),
( (
"e8", "e8e7",
"e7",
"rnbq1bnr/ppppkppp/8/4p3/4P3/8/PPPPKPPP/RNBQ1BNR w - - 2 3", "rnbq1bnr/ppppkppp/8/4p3/4P3/8/PPPPKPPP/RNBQ1BNR w - - 2 3",
), ),
], ],
@ -271,28 +323,23 @@ mod tests {
let (start_pos, moves) = test_case; let (start_pos, moves) = test_case;
// make move // make move
println!("Starting test case {i}, make move."); eprintln!("Starting test case {i}, make move.");
let mut node = Node { let mut node = Node {
pos: BoardState::from_fen(start_pos.to_string()).unwrap(), pos: BoardState::from_fen(start_pos.to_string()).unwrap(),
prev: None, prev: None,
}; };
for (src, dest, expect_fen) in moves { for (move_str, expect_fen) in moves {
println!("Moving {src} to {dest}."); let mv = Move::from_uci_algebraic(move_str).unwrap();
let idx_src = Square::from_algebraic(src).unwrap(); eprintln!("Moving {move_str}.");
let idx_dest = Square::from_algebraic(dest).unwrap();
let mv = Move::Normal(MoveData {
src: idx_src,
dest: idx_dest,
});
node = mv.make(node); node = mv.make(node);
assert_eq!(node.pos.to_fen(), expect_fen.to_string()) assert_eq!(node.pos.to_fen(), expect_fen.to_string())
} }
// unmake move // unmake move
println!("Starting test case {i}, unmake move."); eprintln!("Starting test case {i}, unmake move.");
let mut cur_node = Rc::new(node.clone()); let mut cur_node = Rc::new(node.clone());
for (_, _, expect_fen) in moves.iter().rev().chain([("", "", *start_pos)].iter()) { for (_, expect_fen) in moves.iter().rev().chain([("", *start_pos)].iter()) {
println!("{}", expect_fen); eprintln!("{}", expect_fen);
assert_eq!(*cur_node.pos.to_fen(), expect_fen.to_string()); assert_eq!(*cur_node.pos.to_fen(), expect_fen.to_string());
if *expect_fen != *start_pos { if *expect_fen != *start_pos {
cur_node = cur_node.prev.clone().unwrap(); cur_node = cur_node.prev.clone().unwrap();