feat: torch data loader

dogeystamp 2024-12-30 22:53:59 -05:00
parent 4ba02e9963
commit cbad993a0a
5 changed files with 85 additions and 7 deletions

nnue/.gitignore

@@ -1,3 +1,4 @@
 batches/
 venv/
 train_data/
+__pycache__/

nnue/README.md

@@ -6,15 +6,19 @@ The network is trained on both self-play games, and its games on Lichess.
 Both of these sources provide games in PGN format.
 
 This folder includes the following scripts:
-- `batch_pgn_data.py`: Combine and convert big PGN files into small chunked files.
-- `process_pgn_data.py`: Convert PGN data into a format suitable for training.
+- `s1_batch_pgn_data.py`: Combine and convert big PGN files into small chunked files.
+- `s2_process_pgn_data.py`: Convert PGN data into a format suitable for training.
 
 Example training pipeline:
 
 ```bash
 # chunk all the PGN files in `games/`. outputs by default to `batches/batch%d.pgn`.
-./batch_pgn_data.py games/*.pgn
-# analyze batches 0 to 20 to turn them into training data. outputs by default to train_data/batch%d.tsv.gz.
+./s1_batch_pgn_data.py games/*.pgn
+# analyze batches to turn them into training data. outputs by default to train_data/batch%d.tsv.gz.
 # set max-workers to the number of hardware threads / cores you have.
-./process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch{0..20}.pgn
+# this is the longest part.
+./s2_process_pgn_data.py --engine ../target/release/chess_inator --max-workers 8 batches/batch*.pgn
+
+# combine all processed data into a single training set file.
+zcat train_data/*.tsv.gz | gzip > combined_training.tsv.gz
 ```
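The pipeline ends with `combined_training.tsv.gz`, which the new training script in this commit consumes. Going by that script's parser, each tab-separated row holds a FEN string, a 0/1 multi-hot board encoding, and a centipawn evaluation. A minimal sketch to peek at one row (the column order is inferred from `s3_train_neural_net.py`; the commit never names the columns):

```python
# inspect one processed training row; pandas handles the gzip transparently
import pandas as pd

df = pd.read_csv("combined_training.tsv.gz", delimiter="\t")
print(df.iloc[0])  # FEN, multi-hot board string, centipawn eval
```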

nnue/s1_batch_pgn_data.py

@@ -18,8 +18,6 @@ import itertools
 from pathlib import Path
 
-"""Games to include per file in output."""
 
 parser = argparse.ArgumentParser()
 parser.add_argument("files", nargs="+", type=Path)
 parser.add_argument("--batch-size", type=int, help="Number of games to save in each output file. Set this to two to four times the amount of concurrent workers used in the processing step.", default=8)

nnue/s3_train_neural_net.py (new executable file)

@@ -0,0 +1,75 @@
#!/usr/bin/env python

"""Train the NNUE weights."""

import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from dataclasses import dataclass


################################
################################
## Data loading / parsing
################################
################################


@dataclass
class Position:
    """Single board position."""

    fen: str
    """Normal board representation."""

    board: torch.Tensor
    """Multi-hot board representation."""

    cp_eval: np.double
    """Centipawn evaluation (white perspective)."""

    expected_points: np.double
    """
    Points expected to be gained for white from the game, based on centipawn evaluation.

    - 0: black win
    - 0.5: draw
    - 1: white win
    """


def sigmoid(x):
    """Calculate sigmoid of `x`, using scaling constant `K`."""
    K = 150
    return 1 / (1 + np.exp(-K * x / 400))
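

# worked example: with K = 150, a +100 cp eval (x = 1.0 after the /100 at the
# call site) gives 1 / (1 + exp(-150 * 1.0 / 400)) = 1 / (1 + exp(-0.375))
# ≈ 0.59 expected points for white; x = 0 (an even position) maps to exactly 0.5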


class ChessPositionDataset(Dataset):
    def __init__(self, data_file: Path):
        self.data = pd.read_csv(data_file, delimiter="\t")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # named `cp_eval` rather than `eval` to avoid shadowing the Python built-in
        cp_eval = np.double(row.iloc[2])
        return Position(
            fen=row.iloc[0],
            board=torch.as_tensor([1 if c == "1" else 0 for c in row.iloc[1]]),
            cp_eval=cp_eval,
            expected_points=sigmoid(cp_eval / 100),
        )


if __name__ == "__main__":
    full_dataset = ChessPositionDataset(Path("combined_training.tsv.gz"))

    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [0.8, 0.2])

    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)
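
The script stops at building the loaders. A sketch of how a training loop might consume them, assuming a custom `collate_fn` (since `Position` is a dataclass, which PyTorch's default collation may not handle) and leaving the model and loss as placeholders:

```python
# hypothetical continuation (not part of this commit): batch Position objects
# into tensors and iterate over them
def collate_positions(batch):
    # stack multi-hot boards into one float matrix; expected points are the target
    boards = torch.stack([p.board for p in batch]).to(torch.float32)
    targets = torch.tensor([p.expected_points for p in batch], dtype=torch.float32)
    return boards, targets

loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_positions)
for boards, targets in loader:
    ...  # forward pass, loss against expected points, and backward step go here
```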