stub: rook slider movegen

This commit is contained in:
dogeystamp 2024-10-20 11:37:44 -04:00
parent 5b65f9b756
commit 23f4ff68b4
3 changed files with 227 additions and 24 deletions

View File

@ -4,7 +4,7 @@ pub const START_POSITION: &str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w
pub trait FromFen { pub trait FromFen {
type Error; type Error;
fn from_fen(_: String) -> Result<Self, Self::Error> fn from_fen(_: &str) -> Result<Self, Self::Error>
where where
Self: std::marker::Sized; Self: std::marker::Sized;
} }
@ -36,7 +36,7 @@ pub enum FenError {
impl FromFen for BoardState { impl FromFen for BoardState {
type Error = FenError; type Error = FenError;
fn from_fen(fen: String) -> Result<BoardState, FenError> { fn from_fen(fen: &str) -> Result<BoardState, FenError> {
//! Parse FEN string into position. //! Parse FEN string into position.
/// Parser state machine. /// Parser state machine.
@ -301,7 +301,7 @@ mod tests {
macro_rules! make_board { macro_rules! make_board {
($fen_fmt: expr) => { ($fen_fmt: expr) => {
BoardState::from_fen(format!($fen_fmt)).unwrap() BoardState::from_fen(&format!($fen_fmt)).unwrap()
}; };
} }
@ -437,7 +437,7 @@ mod tests {
for fen1 in test_cases { for fen1 in test_cases {
println!("fen1: {fen1:?}"); println!("fen1: {fen1:?}");
let fen2 = BoardState::from_fen(fen1.to_string()).unwrap().to_fen(); let fen2 = BoardState::from_fen(fen1).unwrap().to_fen();
assert_eq!(fen1.to_string(), fen2, "FEN not equivalent") assert_eq!(fen1.to_string(), fen2, "FEN not equivalent")
} }

View File

@ -1,5 +1,7 @@
#![deny(rust_2018_idioms)] #![deny(rust_2018_idioms)]
use std::str::FromStr;
pub mod fen; pub mod fen;
pub mod movegen; pub mod movegen;
@ -107,7 +109,7 @@ impl ColPiece {
/// Square index newtype. /// Square index newtype.
/// ///
/// A1 is (0, 0) -> 0, A2 is (0, 1) -> 2, and H8 is (7, 7) -> 63. /// A1 is (0, 0) -> 0, A2 is (0, 1) -> 2, and H8 is (7, 7) -> 63.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Square(usize); struct Square(usize);
#[derive(Debug)] #[derive(Debug)]
@ -127,6 +129,28 @@ impl TryFrom<usize> for Square {
} }
} }
} }
macro_rules! sq_try_from {
($T: ty) => {
impl TryFrom<$T> for Square {
type Error = SquareError;
fn try_from(value: $T) -> Result<Self, Self::Error> {
if let Ok(upper_bound) = <$T>::try_from(N_SQUARES) {
if (0..upper_bound).contains(&value) {
return Ok(Square(value as usize));
}
}
Err(SquareError::OutOfBounds)
}
}
};
}
sq_try_from!(i32);
sq_try_from!(isize);
sq_try_from!(i8);
impl From<Square> for usize { impl From<Square> for usize {
fn from(value: Square) -> Self { fn from(value: Square) -> Self {
value.0 value.0
@ -138,6 +162,10 @@ impl Square {
let ret = BOARD_WIDTH * r + c; let ret = BOARD_WIDTH * r + c;
ret.try_into() ret.try_into()
} }
fn from_row_col_signed(r: isize, c: isize) -> Result<Self, SquareError> {
let ret = (BOARD_WIDTH as isize) * r + c;
ret.try_into()
}
fn to_row_col(self) -> (usize, usize) { fn to_row_col(self) -> (usize, usize) {
//! Get row, column from index //! Get row, column from index
let div = self.0 / BOARD_WIDTH; let div = self.0 / BOARD_WIDTH;
@ -155,10 +183,14 @@ impl Square {
let file = letters[col]; let file = letters[col];
format!("{file}{rank}") format!("{file}{rank}")
} }
}
impl FromStr for Square {
type Err = SquareError;
/// Convert typical human-readable form (e.g. `e4`) to square index. /// Convert typical human-readable form (e.g. `e4`) to square index.
fn from_algebraic(value: &str) -> Result<Self, SquareError> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let bytes = value.as_bytes(); let bytes = s.as_bytes();
let col = match bytes[0] as char { let col = match bytes[0] as char {
'a' => 0, 'a' => 0,
'b' => 1, 'b' => 1,
@ -289,8 +321,13 @@ struct Player {
} }
impl Player { impl Player {
/// Get board for a specific piece. /// Get board (non-mutable) for a specific piece.
fn board(&mut self, pc: Piece) -> &mut Bitboard { fn board(&self, pc: Piece) -> &Bitboard {
&self.bit[pc as usize]
}
/// Get board (mutable) for a specific piece.
fn board_mut(&mut self, pc: Piece) -> &mut Bitboard {
&mut self.bit[pc as usize] &mut self.bit[pc as usize]
} }
} }
@ -371,7 +408,7 @@ impl BoardState {
/// Create a new piece in a location. /// Create a new piece in a location.
fn set_piece(&mut self, idx: Square, pc: ColPiece) { fn set_piece(&mut self, idx: Square, pc: ColPiece) {
let pl = self.pl_mut(pc.col); let pl = self.pl_mut(pc.col);
pl.board(pc.into()).on_idx(idx); pl.board_mut(pc.into()).on_idx(idx);
*self.mail.sq_mut(idx) = Some(pc); *self.mail.sq_mut(idx) = Some(pc);
} }
@ -381,7 +418,7 @@ impl BoardState {
fn del_piece(&mut self, idx: Square) -> Result<ColPiece, NoPieceError> { fn del_piece(&mut self, idx: Square) -> Result<ColPiece, NoPieceError> {
if let Some(pc) = *self.mail.sq_mut(idx) { if let Some(pc) = *self.mail.sq_mut(idx) {
let pl = self.pl_mut(pc.col); let pl = self.pl_mut(pc.col);
pl.board(pc.into()).off_idx(idx); pl.board_mut(pc.into()).off_idx(idx);
*self.mail.sq_mut(idx) = None; *self.mail.sq_mut(idx) = None;
Ok(pc) Ok(pc)
} else { } else {
@ -422,13 +459,46 @@ impl core::fmt::Display for BoardState {
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn test_square_casts() {
let fail_cases = [-1, 64, 0x7FFFFFFF, 257, 256, 128, 65, -3, !0x7FFFFFFF];
for tc in fail_cases {
macro_rules! try_type {
($T: ty) => {
if let Ok(conv) = <$T>::try_from(tc) {
assert!(matches!(Square::try_from(conv), Err(SquareError::OutOfBounds)))
}
};
}
try_type!(i32);
try_type!(i8);
try_type!(isize);
try_type!(usize);
}
let good_cases = 0..N_SQUARES;
for tc in good_cases {
macro_rules! try_type {
($T: ty) => {
let conv = <$T>::try_from(tc).unwrap();
let res = Square::try_from(conv).unwrap();
assert_eq!(res.0, tc);
};
}
try_type!(i32);
try_type!(i8);
try_type!(isize);
try_type!(usize);
}
}
#[test] #[test]
fn test_to_from_algebraic() { fn test_to_from_algebraic() {
let test_cases = [("a1", 0), ("a8", 56), ("h1", 7), ("h8", 63)]; let test_cases = [("a1", 0), ("a8", 56), ("h1", 7), ("h8", 63)];
for (sqr, idx) in test_cases { for (sqr, idx) in test_cases {
assert_eq!(Square::try_from(idx).unwrap().to_algebraic(), sqr); assert_eq!(Square::try_from(idx).unwrap().to_algebraic(), sqr);
assert_eq!( assert_eq!(
Square::from_algebraic(sqr).unwrap(), sqr.parse::<Square>().unwrap(),
Square::try_from(idx).unwrap() Square::try_from(idx).unwrap()
); );
} }

View File

@ -18,15 +18,14 @@ struct Node {
impl Default for Node { impl Default for Node {
fn default() -> Self { fn default() -> Self {
Node { Node {
pos: BoardState::from_fen(START_POSITION.to_string()) pos: BoardState::from_fen(START_POSITION).expect("Starting FEN should be valid"),
.expect("Starting FEN should be valid"),
prev: None, prev: None,
} }
} }
} }
/// Piece enum specifically for promotions. /// Piece enum specifically for promotions.
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum PromotePiece { enum PromotePiece {
Rook, Rook,
Bishop, Bishop,
@ -45,17 +44,18 @@ impl From<PromotePiece> for Piece {
} }
} }
/// Pseudo-legal move. #[derive(PartialEq, Eq, PartialOrd, Ord)]
///
/// No checking is done when constructing this.
enum MoveType { enum MoveType {
/// Pawn promotes to another piece. /// Pawn promotes to another piece.
Promotion(PromotePiece), Promotion(PromotePiece),
/// Capture, or push move. Includes castling and en-passant too. /// Capture, or push move. Includes castling and en-passant too.
Normal, Normal,
} }
/// Move data common to all move types. /// Pseudo-legal move.
struct Move { ///
/// No checking is done when constructing this.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Move {
src: Square, src: Square,
dest: Square, dest: Square,
move_type: MoveType, move_type: MoveType,
@ -232,7 +232,7 @@ pub trait ToUCIAlgebraic {
} }
#[derive(Debug)] #[derive(Debug)]
enum MoveAlgebraicError { pub enum MoveAlgebraicError {
/// String is invalid length; refuse to parse /// String is invalid length; refuse to parse
InvalidLength(usize), InvalidLength(usize),
/// Invalid character at given index. /// Invalid character at given index.
@ -250,14 +250,14 @@ impl FromUCIAlgebraic for Move {
return Err(MoveAlgebraicError::InvalidLength(value_len)); return Err(MoveAlgebraicError::InvalidLength(value_len));
} }
let src_sq = match Square::from_algebraic(&value[0..=1]) { let src_sq = match value[0..=1].parse::<Square>() {
Ok(sq) => sq, Ok(sq) => sq,
Err(e) => { Err(e) => {
return Err(MoveAlgebraicError::SquareError(0, e)); return Err(MoveAlgebraicError::SquareError(0, e));
} }
}; };
let dest_sq = match Square::from_algebraic(&value[2..=3]) { let dest_sq = match value[2..=3].parse::<Square>() {
Ok(sq) => sq, Ok(sq) => sq,
Err(e) => { Err(e) => {
return Err(MoveAlgebraicError::SquareError(0, e)); return Err(MoveAlgebraicError::SquareError(0, e));
@ -285,11 +285,144 @@ impl FromUCIAlgebraic for Move {
} }
} }
/// Pseudo-legal move generation.
///
/// "Pseudo-legal" here means that moving into check is allowed, and capturing friendly pieces is
/// allowed. These will be filtered out in the legal move generation step.
pub trait PseudoMoveGen {
type MoveIterable;
fn gen_pseudo_moves(self) -> Self::MoveIterable;
}
enum SliderDirection {
/// Rook movement
Straight,
/// Bishop movement
Diagonal,
/// Queen/king movement
Star,
}
/// Generate slider moves for a given square.
///
/// # Arguments
///
/// * `board`: Board to generate moves with.
/// * `src`: Square on which the slider piece is on.
/// * `move_list`: Vector to append generated moves to.
/// * `slide_type`: Directions the piece is allowed to go in.
/// * `keep_going`: Allow sliding more than one square (true for everything except king).
fn move_slider(
board: &BoardState,
src: Square,
move_list: &mut Vec<Move>,
slide_type: SliderDirection,
keep_going: bool,
) {
let dirs_straight = [(0, 1), (1, 0), (-1, 0), (0, -1)];
let dirs_diag = [(1, 1), (1, -1), (-1, 1), (-1, -1)];
let dirs_star = [
(1, 1),
(1, -1),
(-1, 1),
(-1, -1),
(0, 1),
(1, 0),
(-1, 0),
(0, -1),
];
let dirs = match slide_type {
SliderDirection::Straight => dirs_straight.iter(),
SliderDirection::Diagonal => dirs_diag.iter(),
SliderDirection::Star => dirs_star.iter(),
};
for dir in dirs {
let (mut r, mut c) = src.to_row_col();
loop {
// increment
let nr = r as isize + dir.0;
let nc = c as isize + dir.1;
if let Ok(dest) = Square::from_row_col_signed(nr, nc) {
r = nr as usize;
c = nc as usize;
move_list.push(Move {
src,
dest,
move_type: MoveType::Normal,
});
// Stop at other pieces.
if let Some(_cap_pc) = board.get_piece(dest) {
break;
}
} else {
break;
}
if !keep_going {
break;
}
}
}
}
impl PseudoMoveGen for BoardState {
type MoveIterable = Vec<Move>;
fn gen_pseudo_moves(self) -> Self::MoveIterable {
let mut ret = Vec::new();
for pl in self.players {
for sq in pl.board(Piece::Rook).into_iter() {
move_slider(&self, sq, &mut ret, SliderDirection::Straight, true);
}
}
ret
}
}
/// Legal move generation.
pub trait LegalMoveGen {
type MoveIterable;
fn gen_moves(self) -> Self::MoveIterable;
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::fen::{ToFen, START_POSITION}; use crate::fen::{ToFen, START_POSITION};
/// Test that slider pieces can move and capture.
#[test]
fn test_slider_movegen() {
let test_cases = [(
// start position
"8/8/8/8/8/8/8/R7 w - - 0 1",
// expected moves
vec![(
// source piece
"a1",
// destination squares
vec![
"a2", "a3", "a4", "a5", "a6", "a7", "a8", "b1", "c1", "d1", "e1", "f1", "g1",
"h1",
],
)],
)];
for (fen, expected) in test_cases {
let board = BoardState::from_fen(fen).unwrap();
let mut moves = board.gen_pseudo_moves();
moves.sort_unstable();
let moves = moves;
let expected_moves = expected.iter().map(|(src, dests)| {});
}
}
/// Test that make move and unmake move work as expected. /// Test that make move and unmake move work as expected.
/// ///
/// Ensure that: /// Ensure that:
@ -405,7 +538,7 @@ mod tests {
// make move // make move
eprintln!("Starting test case {i}, make move."); eprintln!("Starting test case {i}, make move.");
let mut node = Node { let mut node = Node {
pos: BoardState::from_fen(start_pos.to_string()).unwrap(), pos: BoardState::from_fen(start_pos).unwrap(),
prev: None, prev: None,
}; };
for (move_str, expect_fen) in moves { for (move_str, expect_fen) in moves {