Compare commits

...

5 Commits

6 changed files with 239 additions and 30 deletions

View File

@ -165,6 +165,8 @@ pub struct MsgBestmove {
pub pv: Vec<Move>,
/// Evaluation of the position
pub eval: SearchEval,
/// Extra information (displayed as `info string`).
pub info: Vec<String>,
}
/// Interface messages that may be received by main's channel.

View File

@ -11,7 +11,7 @@ You should have received a copy of the GNU General Public License along with che
Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
*/
//! Position evaluation.
//! Static position evaluation (hand-crafted eval).
use crate::prelude::*;
use core::cmp::{max, min};

View File

@ -19,6 +19,7 @@ use std::str::FromStr;
pub mod coordination;
pub mod eval;
pub mod nnue;
pub mod fen;
mod hash;
pub mod movegen;
@ -516,9 +517,8 @@ mod ringptr_tests {
/// Ring-buffer of previously seen hashes, used to avoid draw by repetition.
///
/// Only stores at most `HISTORY_SIZE` plies, since most cases of repetition happen recently.
/// Technically, it should be 100 plies because of the 50-move rule.
#[derive(Default, Clone, Copy, Debug)]
/// Only stores at most `HISTORY_SIZE` plies.
#[derive(Clone, Copy, Debug)]
struct BoardHistory {
hashes: [Zobrist; HISTORY_SIZE],
/// Index of the start of the history in the buffer
@ -527,6 +527,17 @@ struct BoardHistory {
ptr_end: RingPtr<HISTORY_SIZE>,
}
impl Default for BoardHistory {
fn default() -> Self {
BoardHistory {
// rust can't derive this
hashes: [Zobrist::default(); HISTORY_SIZE],
ptr_start: Default::default(),
ptr_end: Default::default(),
}
}
}
impl PartialEq for BoardHistory {
/// Always equal, since comparing two boards with different histories shouldn't matter.
fn eq(&self, _other: &Self) -> bool {
@ -539,7 +550,7 @@ impl Eq for BoardHistory {}
/// Size in plies of the board history.
///
/// Actual capacity is one less than this.
const HISTORY_SIZE: usize = 15;
const HISTORY_SIZE: usize = 100;
impl BoardHistory {
/// Counts occurences of this hash in the history.

View File

@ -51,8 +51,9 @@ macro_rules! ignore {
/// UCI engine metadata query.
fn cmd_uci() -> String {
let str = "id name chess_inator\n\
id author dogeystamp\n\
uciok";
id author dogeystamp\n\
option name NNUETrainInfo type check default false\n\
uciok";
str.into()
}
@ -111,15 +112,14 @@ fn cmd_position(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState
/// Play the game.
fn cmd_go(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
let mut wtime = 0;
let mut btime = 0;
let mut wtime: Option<u64> = None;
let mut btime: Option<u64> = None;
let mut movetime: Option<u64> = None;
macro_rules! set_time {
($color: expr, $var: ident) => {
($var: ident) => {
if let Some(time) = tokens.next() {
if let Ok(time) = time.parse::<u64>() {
$var = time;
}
$var = time.parse::<u64>().ok();
}
};
}
@ -127,35 +127,40 @@ fn cmd_go(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
while let Some(token) = tokens.next() {
match token {
"wtime" => {
set_time!(Color::White, wtime)
set_time!(wtime)
}
"btime" => {
set_time!(Color::Black, btime)
set_time!(btime)
}
"movetime" => {
set_time!(movetime)
}
_ => ignore!(),
}
}
let (mut ourtime_ms, theirtime_ms) = if state.board.get_turn() == Color::White {
let (ourtime_ms, theirtime_ms) = if state.board.get_turn() == Color::White {
(wtime, btime)
} else {
(btime, wtime)
};
if ourtime_ms == 0 {
ourtime_ms = 300_000
}
let time_lims = if let Some(movetime) = movetime {
TimeLimits::from_movetime(movetime)
} else {
TimeLimits::from_ourtime_theirtime(
ourtime_ms.unwrap_or(300_000),
theirtime_ms.unwrap_or(300_000),
eval_metrics(&state.board),
)
};
state
.tx_engine
.send(MsgToEngine::Go(Box::new(GoMessage {
board: state.board,
config: state.config,
time_lims: TimeLimits::from_ourtime_theirtime(
ourtime_ms,
theirtime_ms,
eval_metrics(&state.board),
),
time_lims,
})))
.unwrap();
}
@ -202,6 +207,50 @@ fn cmd_eval(mut _tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
println!("- total: {}", res.total_eval);
}
fn match_true_false(s: &str) -> Option<bool> {
match s {
"true" => Some(true),
"false" => Some(false),
_ => None,
}
}
/// Set engine options via UCI.
fn cmd_setoption(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
while let Some(token) = tokens.next() {
fn get_val(mut tokens: std::str::SplitWhitespace<'_>) -> Option<String> {
if let Some("value") = tokens.next() {
if let Some(value) = tokens.next() {
return Some(value.to_string());
}
}
None
}
match token {
"name" => {
if let Some(name) = tokens.next() {
match name {
"NNUETrainInfo" => {
if let Some(value) = get_val(tokens) {
if let Some(value) = match_true_false(&value) {
state.config.nnue_train_info = value;
}
}
}
_ => {
println!("info string Unknown option: {}", name)
}
}
}
}
_ => ignore!(),
}
break;
}
}
/// Root UCI parser.
fn cmd_root(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
while let Some(token) = tokens.next() {
@ -237,6 +286,9 @@ fn cmd_root(mut tokens: std::str::SplitWhitespace<'_>, state: &mut MainState) {
state.tx_engine.send(MsgToEngine::Stop).unwrap();
}
}
"setoption" => {
cmd_setoption(tokens, state);
}
// non-standard command.
"eval" => {
cmd_eval(tokens, state);
@ -269,6 +321,10 @@ fn outp_bestmove(bestmove: MsgBestmove) {
panic!("info string ERROR: stopped search")
}
}
for line in bestmove.info {
println!("info string {line}");
}
match chosen {
Some(mv) => println!("bestmove {}", mv.to_uci_algebraic()),
None => println!("bestmove 0000"),
@ -310,8 +366,23 @@ fn task_engine(tx_main: Sender<MsgToMain>, rx_engine: Receiver<MsgToEngine>) {
state.config = msg_box.config;
state.time_lims = msg_box.time_lims;
let (pv, eval) = best_line(&mut board, &mut state);
let mut info: Vec<String> = Vec::new();
if state.config.nnue_train_info {
let is_quiet = chess_inator::search::is_quiescent_position(&board, eval);
let is_quiet = if is_quiet {"quiet"} else {"non-quiet"};
let board_tensor = chess_inator::nnue::InputTensor::from_board(&board);
info.push(format!("NNUETrainInfo {} {}", is_quiet, {board_tensor}))
}
tx_main
.send(MsgToMain::Bestmove(MsgBestmove { pv, eval }))
.send(MsgToMain::Bestmove(MsgBestmove {
pv,
eval,
info,
}))
.unwrap();
}
MsgToEngine::Stop => {}

92
src/nnue.rs Normal file
View File

@ -0,0 +1,92 @@
/*
This file is part of chess_inator.
chess_inator is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.
chess_inator is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License along with chess_inator. If not, see https://www.gnu.org/licenses/.
Copyright © 2024 dogeystamp <dogeystamp@disroot.org>
*/
//! Static position evaluation (neural network based eval).
//!
//! # Neural net architecture
//!
//! The NNUE has the following layers:
//!
//! * Input (board features)
//! * Hidden layer / accumulator (N neurons)
//! * Output layer (Single neuron)
//!
//! The input layer is a multi-hot binary tensor that represents the board. It is a product of
//! color (2), piece type (6) and piece position (64), giving a total of 768 elements representing
//! for example "is there a _white_, _pawn_ at _e4_?". This information is not enough to represent
//! the board, but is enough for static evaluation purposes. Our NNUE is only expected to run on
//! quiescent positions, and our traditional minmax algorithm will take care of any exchanges, en
//! passant, and other rules that can be mechanically applied.
//!
//! In the engine, the input layer is imaginary. Because of the nature of NNUE (efficiently
//! updatable neural network), we only store the hidden layer's state, and whenever we want to flip
//! a bit in the input layer, we directly add/subtract the corresponding weight from the hidden
//! layer.
use crate::prelude::*;
use std::fmt::Display;
/// Size of the input feature tensor.
pub const INP_TENSOR_SIZE: usize = N_COLORS * N_PIECES * N_SQUARES;
#[derive(Debug, PartialEq, Eq)]
pub struct InputTensor([bool; INP_TENSOR_SIZE]);
/// Input tensor for the NNUE.
///
/// Note that this tensor does not exist at runtime, only during training.
impl InputTensor {
/// Calculate index within the input tensor of a piece/color/square combination.
pub fn idx(pc: ColPiece, sq: Square) -> usize {
let col = pc.col as usize;
let pc = pc.pc as usize;
let sq = sq.0 as usize;
let ret = col * (N_PIECES * N_SQUARES) + pc * (N_SQUARES) + sq;
debug_assert!((0..INP_TENSOR_SIZE).contains(&ret));
ret
}
/// Create the tensor from a board.
pub fn from_board(board: &Board) -> Self {
let mut tensor = [false; INP_TENSOR_SIZE];
for sq in Board::squares() {
if let Some(pc) = board.get_piece(sq) {
let idx = Self::idx(pc, sq);
tensor[idx] = true;
}
}
InputTensor(tensor)
}
}
impl Display for InputTensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = String::from_iter(self.0.map(|x| if x { '1' } else { '0' }));
write!(f, "{}", str)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_to_binary_tensor() {
// more of a sanity check than a test
let board = Board::from_fen("8/8/8/8/8/8/8/1b6 w - - 0 1").unwrap();
let tensor = InputTensor::from_board(&board);
let mut expected = [false; INP_TENSOR_SIZE];
expected[INP_TENSOR_SIZE / N_COLORS + 1 + N_SQUARES] = true;
assert_eq!(tensor.0, expected);
}
}

View File

@ -128,6 +128,8 @@ pub struct SearchConfig {
pub enable_trans_table: bool,
/// Transposition table size (2^n where this is n)
pub transposition_size: usize,
/// Print machine-readable information about the position during NNUE training data generation.
pub nnue_train_info: bool,
}
impl Default for SearchConfig {
@ -139,6 +141,7 @@ impl Default for SearchConfig {
contempt: 0,
enable_trans_table: true,
transposition_size: 24,
nnue_train_info: false,
}
}
}
@ -450,11 +453,7 @@ impl TimeLimits {
/// Make time limits based on wtime, btime (but color-independent).
///
/// Also takes in eval metrics, for instance to avoid wasting too much time in the opening.
pub fn from_ourtime_theirtime(
ourtime_ms: u64,
_theirtime_ms: u64,
eval: EvalMetrics,
) -> TimeLimits {
pub fn from_ourtime_theirtime(ourtime_ms: u64, _theirtime_ms: u64, eval: EvalMetrics) -> Self {
// hard timeout (max)
let mut hard_ms = 100_000;
// soft timeout (default max)
@ -488,6 +487,16 @@ impl TimeLimits {
soft: Some(soft_limit),
}
}
/// Make time limit based on an exact hard limit.
pub fn from_movetime(movetime_ms: u64) -> Self {
let hard_limit = Instant::now() + Duration::from_millis(movetime_ms);
TimeLimits {
hard: Some(hard_limit),
soft: None,
}
}
}
/// Helper type to avoid retyping the same arguments into every function prototype.
@ -540,3 +549,27 @@ pub fn best_move(board: &mut Board, engine_state: &mut EngineState) -> Option<Mo
let (line, _eval) = best_line(board, engine_state);
line.last().copied()
}
/// Utility for NNUE training set generation to determine if a position is quiet or not.
///
/// Our definition of "quiet" is that there are no checks, and the static and quiescence search
/// evaluations are similar. (See https://arxiv.org/html/2412.17948v1.)
///
/// It is the caller's responsibility to get the search evaluation and pass it to this function.
pub fn is_quiescent_position(board: &Board, eval: SearchEval) -> bool {
// max centipawn value difference to call "similar"
const THRESHOLD: EvalInt = 170;
if board.is_check(board.turn) {
return false;
}
if matches!(eval, SearchEval::Checkmate(_)) {
return false;
}
// white perspective
let abs_eval = EvalInt::from(eval) * EvalInt::from(board.turn.sign());
(board.eval() - EvalInt::from(abs_eval)).abs() <= THRESHOLD.abs()
}