Compare commits
5 Commits
caa3bc454c
...
87501b5c94
Author | SHA1 | Date | |
---|---|---|---|
87501b5c94 | |||
47a600cd80 | |||
e27e18e482 | |||
ede46552fe | |||
fc8eab4d4b |
@ -165,6 +165,8 @@ pub struct MsgBestmove {
|
||||
pub pv: Vec<Move>,
|
||||
/// Evaluation of the position
|
||||
pub eval: SearchEval,
|
||||
/// Extra information (displayed as `info string`).
|
||||
pub info: Vec<String>,
|
||||
}
|
||||
|
||||
/// Interface messages that may be received by main's channel.
|
||||
|
@ -11,7 +11,7 @@ You should have received a copy of the GNU General Public License along with che
|
||||
Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
|
||||
*/
|
||||
|
||||
//! Position evaluation.
|
||||
//! Static position evaluation (hand-crafted eval).
|
||||
|
||||
use crate::prelude::*;
|
||||
use core::cmp::{max, min};
|
||||
|
19
src/lib.rs
19
src/lib.rs
@ -19,6 +19,7 @@ use std::str::FromStr;
|
||||
|
||||
pub mod coordination;
|
||||
pub mod eval;
|
||||
pub mod nnue;
|
||||
pub mod fen;
|
||||
mod hash;
|
||||
pub mod movegen;
|
||||
@ -516,9 +517,8 @@ mod ringptr_tests {
|
||||
|
||||
/// Ring-buffer of previously seen hashes, used to avoid draw by repetition.
|
||||
///
|
||||
/// Only stores at most `HISTORY_SIZE` plies, since most cases of repetition happen recently.
|
||||
/// Technically, it should be 100 plies because of the 50-move rule.
|
||||
#[derive(Default, Clone, Copy, Debug)]
|
||||
/// Only stores at most `HISTORY_SIZE` plies.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct BoardHistory {
|
||||
hashes: [Zobrist; HISTORY_SIZE],
|
||||
/// Index of the start of the history in the buffer
|
||||
@ -527,6 +527,17 @@ struct BoardHistory {
|
||||
ptr_end: RingPtr<HISTORY_SIZE>,
|
||||
}
|
||||
|
||||
impl Default for BoardHistory {
|
||||
fn default() -> Self {
|
||||
BoardHistory {
|
||||
// rust can't derive this
|
||||
hashes: [Zobrist::default(); HISTORY_SIZE],
|
||||
ptr_start: Default::default(),
|
||||
ptr_end: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for BoardHistory {
|
||||
/// Always equal, since comparing two boards with different histories shouldn't matter.
|
||||
fn eq(&self, _other: &Self) -> bool {
|
||||
@ -539,7 +550,7 @@ impl Eq for BoardHistory {}
|
||||
/// Size in plies of the board history.
|
||||
///
|
||||
/// Actual capacity is one less than this.
|
||||
const HISTORY_SIZE: usize = 15;
|
||||
const HISTORY_SIZE: usize = 100;
|
||||
|
||||
impl BoardHistory {
|
||||
/// Counts occurences of this hash in the history.
|
||||
|
111
src/main.rs
111
src/main.rs
@ -51,8 +51,9 @@ macro_rules! ignore {
|
||||
/// UCI engine metadata query.
|
||||
fn cmd_uci() -> String {
|
||||
let str = "id name chess_inator\n\
|
||||
id author dogeystamp\n\
|
||||
uciok";
|
||||
id author dogeystamp\n\
|
||||
option name NNUETrainInfo type check default false\n\
|
||||
uciok";
|
||||
str.into()
|
||||
}
|
||||
|
||||
@ -111,15 +112,14 @@ fn cmd_position(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState
|
||||
|
||||
/// Play the game.
|
||||
fn cmd_go(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
let mut wtime = 0;
|
||||
let mut btime = 0;
|
||||
let mut wtime: Option<u64> = None;
|
||||
let mut btime: Option<u64> = None;
|
||||
let mut movetime: Option<u64> = None;
|
||||
|
||||
macro_rules! set_time {
|
||||
($color: expr, $var: ident) => {
|
||||
($var: ident) => {
|
||||
if let Some(time) = tokens.next() {
|
||||
if let Ok(time) = time.parse::<u64>() {
|
||||
$var = time;
|
||||
}
|
||||
$var = time.parse::<u64>().ok();
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -127,35 +127,40 @@ fn cmd_go(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
while let Some(token) = tokens.next() {
|
||||
match token {
|
||||
"wtime" => {
|
||||
set_time!(Color::White, wtime)
|
||||
set_time!(wtime)
|
||||
}
|
||||
"btime" => {
|
||||
set_time!(Color::Black, btime)
|
||||
set_time!(btime)
|
||||
}
|
||||
"movetime" => {
|
||||
set_time!(movetime)
|
||||
}
|
||||
_ => ignore!(),
|
||||
}
|
||||
}
|
||||
|
||||
let (mut ourtime_ms, theirtime_ms) = if state.board.get_turn() == Color::White {
|
||||
let (ourtime_ms, theirtime_ms) = if state.board.get_turn() == Color::White {
|
||||
(wtime, btime)
|
||||
} else {
|
||||
(btime, wtime)
|
||||
};
|
||||
|
||||
if ourtime_ms == 0 {
|
||||
ourtime_ms = 300_000
|
||||
}
|
||||
let time_lims = if let Some(movetime) = movetime {
|
||||
TimeLimits::from_movetime(movetime)
|
||||
} else {
|
||||
TimeLimits::from_ourtime_theirtime(
|
||||
ourtime_ms.unwrap_or(300_000),
|
||||
theirtime_ms.unwrap_or(300_000),
|
||||
eval_metrics(&state.board),
|
||||
)
|
||||
};
|
||||
|
||||
state
|
||||
.tx_engine
|
||||
.send(MsgToEngine::Go(Box::new(GoMessage {
|
||||
board: state.board,
|
||||
config: state.config,
|
||||
time_lims: TimeLimits::from_ourtime_theirtime(
|
||||
ourtime_ms,
|
||||
theirtime_ms,
|
||||
eval_metrics(&state.board),
|
||||
),
|
||||
time_lims,
|
||||
})))
|
||||
.unwrap();
|
||||
}
|
||||
@ -202,6 +207,50 @@ fn cmd_eval(mut _tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
println!("- total: {}", res.total_eval);
|
||||
}
|
||||
|
||||
fn match_true_false(s: &str) -> Option<bool> {
|
||||
match s {
|
||||
"true" => Some(true),
|
||||
"false" => Some(false),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set engine options via UCI.
|
||||
fn cmd_setoption(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
while let Some(token) = tokens.next() {
|
||||
fn get_val(mut tokens: std::str::SplitWhitespace<'_>) -> Option<String> {
|
||||
if let Some("value") = tokens.next() {
|
||||
if let Some(value) = tokens.next() {
|
||||
return Some(value.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
match token {
|
||||
"name" => {
|
||||
if let Some(name) = tokens.next() {
|
||||
match name {
|
||||
"NNUETrainInfo" => {
|
||||
if let Some(value) = get_val(tokens) {
|
||||
if let Some(value) = match_true_false(&value) {
|
||||
state.config.nnue_train_info = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
println!("info string Unknown option: {}", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => ignore!(),
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/// Root UCI parser.
|
||||
fn cmd_root(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
while let Some(token) = tokens.next() {
|
||||
@ -237,6 +286,9 @@ fn cmd_root(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
|
||||
state.tx_engine.send(MsgToEngine::Stop).unwrap();
|
||||
}
|
||||
}
|
||||
"setoption" => {
|
||||
cmd_setoption(tokens, state);
|
||||
}
|
||||
// non-standard command.
|
||||
"eval" => {
|
||||
cmd_eval(tokens, state);
|
||||
@ -269,6 +321,10 @@ fn outp_bestmove(bestmove: MsgBestmove) {
|
||||
panic!("info string ERROR: stopped search")
|
||||
}
|
||||
}
|
||||
for line in bestmove.info {
|
||||
println!("info string {line}");
|
||||
}
|
||||
|
||||
match chosen {
|
||||
Some(mv) => println!("bestmove {}", mv.to_uci_algebraic()),
|
||||
None => println!("bestmove 0000"),
|
||||
@ -310,8 +366,23 @@ fn task_engine(tx_main: Sender<MsgToMain>, rx_engine: Receiver<MsgToEngine>) {
|
||||
state.config = msg_box.config;
|
||||
state.time_lims = msg_box.time_lims;
|
||||
let (pv, eval) = best_line(&mut board, &mut state);
|
||||
|
||||
let mut info: Vec<String> = Vec::new();
|
||||
if state.config.nnue_train_info {
|
||||
let is_quiet = chess_inator::search::is_quiescent_position(&board, eval);
|
||||
let is_quiet = if is_quiet {"quiet"} else {"non-quiet"};
|
||||
|
||||
let board_tensor = chess_inator::nnue::InputTensor::from_board(&board);
|
||||
|
||||
info.push(format!("NNUETrainInfo {} {}", is_quiet, {board_tensor}))
|
||||
}
|
||||
|
||||
tx_main
|
||||
.send(MsgToMain::Bestmove(MsgBestmove { pv, eval }))
|
||||
.send(MsgToMain::Bestmove(MsgBestmove {
|
||||
pv,
|
||||
eval,
|
||||
info,
|
||||
}))
|
||||
.unwrap();
|
||||
}
|
||||
MsgToEngine::Stop => {}
|
||||
|
92
src/nnue.rs
Normal file
92
src/nnue.rs
Normal file
@ -0,0 +1,92 @@
|
||||
/*
|
||||
|
||||
This file is part of chess_inator.
|
||||
|
||||
chess_inator is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.
|
||||
|
||||
chess_inator is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License along with chess_inator. If not, see https://www.gnu.org/licenses/.
|
||||
|
||||
Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
|
||||
*/
|
||||
|
||||
//! Static position evaluation (neural network based eval).
|
||||
//!
|
||||
//! # Neural net architecture
|
||||
//!
|
||||
//! The NNUE has the following layers:
|
||||
//!
|
||||
//! * Input (board features)
|
||||
//! * Hidden layer / accumulator (N neurons)
|
||||
//! * Output layer (Single neuron)
|
||||
//!
|
||||
//! The input layer is a multi-hot binary tensor that represents the board. It is a product of
|
||||
//! color (2), piece type (6) and piece position (64), giving a total of 768 elements representing
|
||||
//! for example "is there a _white_, _pawn_ at _e4_?". This information is not enough to represent
|
||||
//! the board, but is enough for static evaluation purposes. Our NNUE is only expected to run on
|
||||
//! quiescent positions, and our traditional minmax algorithm will take care of any exchanges, en
|
||||
//! passant, and other rules that can be mechanically applied.
|
||||
//!
|
||||
//! In the engine, the input layer is imaginary. Because of the nature of NNUE (efficiently
|
||||
//! updatable neural network), we only store the hidden layer's state, and whenever we want to flip
|
||||
//! a bit in the input layer, we directly add/subtract the corresponding weight from the hidden
|
||||
//! layer.
|
||||
|
||||
use crate::prelude::*;
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Size of the input feature tensor.
|
||||
pub const INP_TENSOR_SIZE: usize = N_COLORS * N_PIECES * N_SQUARES;
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct InputTensor([bool; INP_TENSOR_SIZE]);
|
||||
|
||||
/// Input tensor for the NNUE.
|
||||
///
|
||||
/// Note that this tensor does not exist at runtime, only during training.
|
||||
impl InputTensor {
|
||||
/// Calculate index within the input tensor of a piece/color/square combination.
|
||||
pub fn idx(pc: ColPiece, sq: Square) -> usize {
|
||||
let col = pc.col as usize;
|
||||
let pc = pc.pc as usize;
|
||||
let sq = sq.0 as usize;
|
||||
|
||||
let ret = col * (N_PIECES * N_SQUARES) + pc * (N_SQUARES) + sq;
|
||||
debug_assert!((0..INP_TENSOR_SIZE).contains(&ret));
|
||||
ret
|
||||
}
|
||||
|
||||
/// Create the tensor from a board.
|
||||
pub fn from_board(board: &Board) -> Self {
|
||||
let mut tensor = [false; INP_TENSOR_SIZE];
|
||||
for sq in Board::squares() {
|
||||
if let Some(pc) = board.get_piece(sq) {
|
||||
let idx = Self::idx(pc, sq);
|
||||
tensor[idx] = true;
|
||||
}
|
||||
}
|
||||
|
||||
InputTensor(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for InputTensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let str = String::from_iter(self.0.map(|x| if x { '1' } else { '0' }));
|
||||
write!(f, "{}", str)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_to_binary_tensor() {
|
||||
// more of a sanity check than a test
|
||||
let board = Board::from_fen("8/8/8/8/8/8/8/1b6 w - - 0 1").unwrap();
|
||||
let tensor = InputTensor::from_board(&board);
|
||||
let mut expected = [false; INP_TENSOR_SIZE];
|
||||
expected[INP_TENSOR_SIZE / N_COLORS + 1 + N_SQUARES] = true;
|
||||
assert_eq!(tensor.0, expected);
|
||||
}
|
||||
}
|
@ -128,6 +128,8 @@ pub struct SearchConfig {
|
||||
pub enable_trans_table: bool,
|
||||
/// Transposition table size (2^n where this is n)
|
||||
pub transposition_size: usize,
|
||||
/// Print machine-readable information about the position during NNUE training data generation.
|
||||
pub nnue_train_info: bool,
|
||||
}
|
||||
|
||||
impl Default for SearchConfig {
|
||||
@ -139,6 +141,7 @@ impl Default for SearchConfig {
|
||||
contempt: 0,
|
||||
enable_trans_table: true,
|
||||
transposition_size: 24,
|
||||
nnue_train_info: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -450,11 +453,7 @@ impl TimeLimits {
|
||||
/// Make time limits based on wtime, btime (but color-independent).
|
||||
///
|
||||
/// Also takes in eval metrics, for instance to avoid wasting too much time in the opening.
|
||||
pub fn from_ourtime_theirtime(
|
||||
ourtime_ms: u64,
|
||||
_theirtime_ms: u64,
|
||||
eval: EvalMetrics,
|
||||
) -> TimeLimits {
|
||||
pub fn from_ourtime_theirtime(ourtime_ms: u64, _theirtime_ms: u64, eval: EvalMetrics) -> Self {
|
||||
// hard timeout (max)
|
||||
let mut hard_ms = 100_000;
|
||||
// soft timeout (default max)
|
||||
@ -488,6 +487,16 @@ impl TimeLimits {
|
||||
soft: Some(soft_limit),
|
||||
}
|
||||
}
|
||||
|
||||
/// Make time limit based on an exact hard limit.
|
||||
pub fn from_movetime(movetime_ms: u64) -> Self {
|
||||
let hard_limit = Instant::now() + Duration::from_millis(movetime_ms);
|
||||
|
||||
TimeLimits {
|
||||
hard: Some(hard_limit),
|
||||
soft: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper type to avoid retyping the same arguments into every function prototype.
|
||||
@ -540,3 +549,27 @@ pub fn best_move(board: &mut Board, engine_state: &mut EngineState) -> Option<Mo
|
||||
let (line, _eval) = best_line(board, engine_state);
|
||||
line.last().copied()
|
||||
}
|
||||
|
||||
/// Utility for NNUE training set generation to determine if a position is quiet or not.
|
||||
///
|
||||
/// Our definition of "quiet" is that there are no checks, and the static and quiescence search
|
||||
/// evaluations are similar. (See https://arxiv.org/html/2412.17948v1.)
|
||||
///
|
||||
/// It is the caller's responsibility to get the search evaluation and pass it to this function.
|
||||
pub fn is_quiescent_position(board: &Board, eval: SearchEval) -> bool {
|
||||
// max centipawn value difference to call "similar"
|
||||
const THRESHOLD: EvalInt = 170;
|
||||
|
||||
if board.is_check(board.turn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if matches!(eval, SearchEval::Checkmate(_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// white perspective
|
||||
let abs_eval = EvalInt::from(eval) * EvalInt::from(board.turn.sign());
|
||||
|
||||
(board.eval() - EvalInt::from(abs_eval)).abs() <= THRESHOLD.abs()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user