refactor: movegen

This commit is contained in:
dogeystamp 2024-10-25 21:25:38 -04:00
parent e4b19b8bdd
commit 3ed7f315c8
3 changed files with 102 additions and 110 deletions

View File

@ -1,14 +1,14 @@
//! Generates moves from the FEN in the argv. //! Generates moves from the FEN in the argv.
use chess_inator::fen::FromFen; use chess_inator::fen::FromFen;
use chess_inator::movegen::LegalMoveGen; use chess_inator::movegen::{MoveGen, MoveGenType};
use chess_inator::Board; use chess_inator::Board;
use std::env; use std::env;
fn main() { fn main() {
let fen = env::args().nth(1).unwrap(); let fen = env::args().nth(1).unwrap();
let board = Board::from_fen(&fen).unwrap(); let mut board = Board::from_fen(&fen).unwrap();
let mvs = board.gen_moves(); let mvs = board.gen_moves(MoveGenType::Legal);
for mv in mvs.into_iter() { for mv in mvs.into_iter() {
println!("{mv:?}") println!("{mv:?}")
} }

View File

@ -7,6 +7,7 @@ pub mod fen;
pub mod movegen; pub mod movegen;
use crate::fen::{FromFen, ToFen, START_POSITION}; use crate::fen::{FromFen, ToFen, START_POSITION};
use crate::movegen::Move;
const BOARD_WIDTH: usize = 8; const BOARD_WIDTH: usize = 8;
const BOARD_HEIGHT: usize = 8; const BOARD_HEIGHT: usize = 8;
@ -529,6 +530,55 @@ impl Board {
new_board new_board
} }
/// Is a given player in check?
fn is_check(&self, pl: Color) -> bool {
for src in self.pl(pl).board(Piece::King).into_iter() {
macro_rules! detect_checker {
($dirs: ident, $pc: pat, $keep_going: expr) => {
for dir in $dirs.into_iter() {
let (mut r, mut c) = src.to_row_col_signed();
loop {
let (nr, nc) = (r + dir.0, c + dir.1);
if let Ok(sq) = Square::from_row_col_signed(nr, nc) {
if let Some(pc) = self.get_piece(sq) {
if matches!(pc.pc, $pc) && pc.col != pl {
return true;
} else {
break;
}
}
} else {
break;
}
if (!($keep_going)) {
break;
}
r = nr;
c = nc;
}
}
};
}
let dirs_white_pawn = [(-1, 1), (-1, -1)];
let dirs_black_pawn = [(1, 1), (1, -1)];
use Piece::*;
use movegen::{DIRS_DIAG, DIRS_KNIGHT, DIRS_STAR, DIRS_STRAIGHT};
detect_checker!(DIRS_DIAG, Bishop | Queen, true);
detect_checker!(DIRS_STRAIGHT, Rook | Queen, true);
detect_checker!(DIRS_STAR, King, false);
detect_checker!(DIRS_KNIGHT, Knight, false);
match pl {
Color::White => detect_checker!(dirs_black_pawn, Pawn, false),
Color::Black => detect_checker!(dirs_white_pawn, Pawn, false),
}
}
false
}
/// Maximum amount of moves in the counter to parse before giving up /// Maximum amount of moves in the counter to parse before giving up
const MAX_MOVES: usize = 9_999; const MAX_MOVES: usize = 9_999;
} }

View File

@ -5,7 +5,6 @@ use crate::{
Board, CastleRights, ColPiece, Color, Piece, Square, SquareError, BOARD_HEIGHT, BOARD_WIDTH, Board, CastleRights, ColPiece, Color, Piece, Square, SquareError, BOARD_HEIGHT, BOARD_WIDTH,
N_SQUARES, N_SQUARES,
}; };
use std::rc::Rc;
/// Piece enum specifically for promotions. /// Piece enum specifically for promotions.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
@ -394,18 +393,21 @@ impl FromUCIAlgebraic for Move {
} }
} }
/// Pseudo-legal move generation. #[derive(Debug, Clone, Copy)]
/// pub enum MoveGenType {
/// "Pseudo-legal" here means that moving into check (but not castling through check) is allowed, /// Legal move generation.
/// and capturing friendly pieces is allowed. These will be filtered out in the legal move Legal,
/// generation step. /// Allow capturing friendly pieces, moving into check, but not castling through check.
pub trait PseudoMoveGen { Pseudo,
fn gen_pseudo_moves(&self) -> impl IntoIterator<Item = Move>;
} }
const DIRS_STRAIGHT: [(isize, isize); 4] = [(0, 1), (1, 0), (-1, 0), (0, -1)]; pub trait MoveGen {
const DIRS_DIAG: [(isize, isize); 4] = [(1, 1), (1, -1), (-1, 1), (-1, -1)]; fn gen_moves(&mut self, gen_type: MoveGenType) -> impl IntoIterator<Item = Move>;
const DIRS_STAR: [(isize, isize); 8] = [ }
pub const DIRS_STRAIGHT: [(isize, isize); 4] = [(0, 1), (1, 0), (-1, 0), (0, -1)];
pub const DIRS_DIAG: [(isize, isize); 4] = [(1, 1), (1, -1), (-1, 1), (-1, -1)];
pub const DIRS_STAR: [(isize, isize); 8] = [
(1, 1), (1, 1),
(1, -1), (1, -1),
(-1, 1), (-1, 1),
@ -415,7 +417,7 @@ const DIRS_STAR: [(isize, isize); 8] = [
(-1, 0), (-1, 0),
(0, -1), (0, -1),
]; ];
const DIRS_KNIGHT: [(isize, isize); 8] = [ pub const DIRS_KNIGHT: [(isize, isize); 8] = [
(2, 1), (2, 1),
(1, 2), (1, 2),
(-1, 2), (-1, 2),
@ -486,9 +488,31 @@ fn move_slider(
} }
} }
} }
fn is_legal(board: &mut Board, mv: Move) -> bool {
// mut required for check checking
// disallow friendly fire
let src_pc = board
.get_piece(mv.src)
.expect("move source should have piece");
if let Some(dest_pc) = board.get_piece(mv.dest) {
if dest_pc.col == src_pc.col {
return false;
}
}
impl PseudoMoveGen for Board { // disallow moving into check
fn gen_pseudo_moves(&self) -> impl IntoIterator<Item = Move> { let anti_move = mv.make(board);
let is_check = board.is_check(board.turn.flip());
anti_move.unmake(board);
if is_check {
return false;
}
true
}
impl MoveGen for Board {
fn gen_moves(&mut self, gen_type: MoveGenType) -> impl IntoIterator<Item = Move> {
let mut ret = Vec::new(); let mut ret = Vec::new();
let pl = self.pl(self.turn); let pl = self.pl(self.turn);
macro_rules! squares { macro_rules! squares {
@ -564,7 +588,7 @@ impl PseudoMoveGen for Board {
.map(|dest| { .map(|dest| {
let mut board = *self; let mut board = *self;
board.move_piece(src, dest); board.move_piece(src, dest);
is_check(&board, self.turn) board.is_check(self.turn)
}) })
.any(|x| x); .any(|x| x);
if is_any_checked { if is_any_checked {
@ -670,87 +694,10 @@ impl PseudoMoveGen for Board {
} }
} }
} }
ret ret.into_iter().filter(move |mv| match gen_type {
} MoveGenType::Legal => is_legal(self, *mv),
} MoveGenType::Pseudo => true,
})
/// Legal move generation.
pub trait LegalMoveGen {
fn gen_moves(&self) -> impl IntoIterator<Item = Move>;
}
/// Is a given player in check?
fn is_check(board: &Board, pl: Color) -> bool {
for src in board.pl(pl).board(Piece::King).into_iter() {
macro_rules! detect_checker {
($dirs: ident, $pc: pat, $keep_going: expr) => {
for dir in $dirs.into_iter() {
let (mut r, mut c) = src.to_row_col_signed();
loop {
let (nr, nc) = (r + dir.0, c + dir.1);
if let Ok(sq) = Square::from_row_col_signed(nr, nc) {
if let Some(pc) = board.get_piece(sq) {
if matches!(pc.pc, $pc) && pc.col != pl {
return true;
} else {
break;
}
}
} else {
break;
}
if (!($keep_going)) {
break;
}
r = nr;
c = nc;
}
}
};
}
let dirs_white_pawn = [(-1, 1), (-1, -1)];
let dirs_black_pawn = [(1, 1), (1, -1)];
use Piece::*;
detect_checker!(DIRS_DIAG, Bishop | Queen, true);
detect_checker!(DIRS_STRAIGHT, Rook | Queen, true);
detect_checker!(DIRS_STAR, King, false);
detect_checker!(DIRS_KNIGHT, Knight, false);
match pl {
Color::White => detect_checker!(dirs_black_pawn, Pawn, false),
Color::Black => detect_checker!(dirs_white_pawn, Pawn, false),
}
}
false
}
impl LegalMoveGen for Board {
// mut required for check checking
fn gen_moves(&self) -> impl IntoIterator<Item = Move> {
let mut pos = *self;
pos.gen_pseudo_moves()
.into_iter()
.filter(|mv| {
// disallow friendly fire
let src_pc = self
.get_piece(mv.src)
.expect("move source should have piece");
if let Some(dest_pc) = self.get_piece(mv.dest) {
return dest_pc.col != src_pc.col;
}
true
})
.filter(move |mv| {
// disallow moving into check
let anti_move = mv.make(&mut pos);
let ret = !is_check(&pos, self.turn);
anti_move.unmake(&mut pos);
ret
})
.collect::<Vec<_>>()
} }
} }
@ -762,7 +709,7 @@ pub fn perft(depth: usize, pos: &mut Board) -> usize {
let mut ans = 0; let mut ans = 0;
let moves: Vec<Move> = pos.gen_moves().into_iter().collect(); let moves: Vec<Move> = pos.gen_moves(MoveGenType::Legal).into_iter().collect();
for mv in moves { for mv in moves {
let anti_move = mv.make(pos); let anti_move = mv.make(pos);
ans += perft(depth - 1, pos); ans += perft(depth - 1, pos);
@ -1083,8 +1030,8 @@ mod tests {
let augmented_test_cases = test_cases.clone().map(|tc| flip_test_case(tc.0, &tc.1)); let augmented_test_cases = test_cases.clone().map(|tc| flip_test_case(tc.0, &tc.1));
let all_cases = [augmented_test_cases, test_cases].concat(); let all_cases = [augmented_test_cases, test_cases].concat();
for (board, expected_moves) in all_cases { for (mut board, expected_moves) in all_cases {
let mut moves: Vec<Move> = board.gen_pseudo_moves().into_iter().collect(); let mut moves: Vec<Move> = board.gen_moves(MoveGenType::Pseudo).into_iter().collect();
moves.sort_unstable(); moves.sort_unstable();
let moves = moves; let moves = moves;
@ -1121,16 +1068,11 @@ mod tests {
let all_cases = check_cases.iter().chain(&not_check_cases); let all_cases = check_cases.iter().chain(&not_check_cases);
for (fen, expected) in all_cases { for (fen, expected) in all_cases {
let board = Board::from_fen(fen).unwrap(); let board = Board::from_fen(fen).unwrap();
assert_eq!( assert_eq!(board.is_check(Color::White), *expected, "failed on {}", fen);
is_check(&board, Color::White),
*expected,
"failed on {}",
fen
);
let board_anti = board.flip_colors(); let board_anti = board.flip_colors();
assert_eq!( assert_eq!(
is_check(&board_anti, Color::Black), board_anti.is_check(Color::Black),
*expected, *expected,
"failed on anti-version of {} ({})", "failed on anti-version of {} ({})",
fen, fen,
@ -1244,7 +1186,7 @@ mod tests {
expected_moves.sort_unstable(); expected_moves.sort_unstable();
let expected_moves = expected_moves; let expected_moves = expected_moves;
let mut moves: Vec<Move> = board.gen_moves().into_iter().collect(); let mut moves: Vec<Move> = board.gen_moves(MoveGenType::Legal).into_iter().collect();
moves.sort_unstable(); moves.sort_unstable();
let moves = moves; let moves = moves;