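"""Fine-tuning script for a GPTPyTorch model.

Cleans the raw dialogue dataset if needed, tokenizes it with tiktoken (GPT-2
encoding), splits it into train/validation sequences, fine-tunes the model
with AdamW, and keeps rolling per-epoch checkpoints plus a backup of the
previous "last trained" weights.
"""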
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tiktoken
from tqdm import tqdm
import shutil
import math
from pathlib import Path
import re

from gpt_pytorch import GPTPyTorch
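
# Training hyperparameters. KEEP_LAST_EPOCHS controls how many per-epoch
# checkpoint directories are kept on disk; VAL_SPLIT_RATIO is the fraction of
# sequences held out for validation.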
TRAIN_SEQ_LEN = 256
BATCH_SIZE = 12
EPOCHS = 50
LEARNING_RATE = 6e-6
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
KEEP_LAST_EPOCHS = 3
VAL_SPLIT_RATIO = 0.05
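
# Checkpoint locations: BASE_MODEL_PATH is the pretrained starting point;
# LAST_TRAINED_PATH is overwritten at the end of every run and is preferred
# when resuming.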
BASE_MODEL_PATH = Path("models/JiRack_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
LAST_TRAINED_PATH = Path("models/JiRack_last_H12_L6_V50257_D768_MSL8192_FF768x4.pt")
BACKUP_DIR = Path("models/backups")
BACKUP_DIR.mkdir(parents=True, exist_ok=True)
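
# Dataset paths: RAW_PATH is the raw user-provided text; CLEAN_PATH is the
# cleaned copy that training actually reads.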
RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
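
# Decide whether the dataset needs (re-)cleaning: the clean copy is missing,
# or the raw file has been modified more recently than the clean copy.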
force_clean = False
if not CLEAN_PATH.exists():
    print("Clean dataset not found. Performing initial cleaning...")
    force_clean = True
else:
    try:
        if RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
            print("Changes detected in the source dataset. Performing re-cleaning...")
            force_clean = True
        else:
            print(f"Using existing clean dataset → {CLEAN_PATH}")
    except FileNotFoundError:
        print("File system synchronization error. Performing re-cleaning for safety...")
        force_clean = True

if force_clean:
    if not RAW_PATH.exists():
        raise FileNotFoundError(f"ERROR: Source file {RAW_PATH} not found. Check the path.")

    print("Cleaning dataset from garbage (extra spaces, incorrect separators)...")
    text = RAW_PATH.read_text(encoding="utf-8")

    # Collapse runs of spaces and strip spaces that touch newlines.
    text = re.sub(r' {2,}', ' ', text)
    text = text.replace(" \n", "\n").replace("\n ", "\n")

    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Dataset successfully cleaned and saved → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH
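
# Where per-epoch and final checkpoints are written during fine-tuning.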
OUTPUT_DIR = Path("build/fine_tuning_output")
MODEL_SAVE_NAME = "gpt_finetuned.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


class TextDataset(Dataset):
    """Next-token dataset: encodes the whole text with tiktoken and slices it
    into non-overlapping windows of seq_len tokens; labels are the same window
    shifted right by one token."""

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2",
                 split_type='train', val_ratio=VAL_SPLIT_RATIO):
        self.seq_len = seq_len

        print(f"Loading tiktoken encoding '{encoding_name}' (small file auto-downloads on first run if needed)...")
        self.enc = tiktoken.get_encoding(encoding_name)

        self.split_type = split_type

        print(f"Loading text from {text_file} for {split_type} split...")
        text = Path(text_file).read_text(encoding="utf-8")
        tokens = self.enc.encode(text)

        if len(tokens) < seq_len * 2:
            raise ValueError(f"Text too short: {len(tokens)} tokens, need at least {seq_len * 2}.")

        all_inputs = []
        all_labels = []

        # Non-overlapping windows; each label sequence is its input shifted by one token.
        for i in range(0, len(tokens) - seq_len, seq_len):
            all_inputs.append(tokens[i:i + seq_len])
            all_labels.append(tokens[i + 1:i + seq_len + 1])

        total_sequences = len(all_inputs)
        val_size = int(total_sequences * val_ratio)
        train_size = total_sequences - val_size

        # The last val_size sequences form the validation split.
        if self.split_type == 'train':
            self.inputs = all_inputs[:train_size]
            self.labels = all_labels[:train_size]
        elif self.split_type == 'val':
            self.inputs = all_inputs[train_size:]
            self.labels = all_labels[train_size:]
        else:
            raise ValueError("Invalid split_type. Must be 'train' or 'val'.")

        print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return (torch.tensor(self.inputs[idx], dtype=torch.long),
                torch.tensor(self.labels[idx], dtype=torch.long))


def evaluate(model, dataloader, criterion, device):
    """Return the average cross-entropy loss over a dataloader, without gradients."""
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits, _ = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            total_loss += loss.item()

    # Guard against an empty loader (possible with a tiny validation split and drop_last=True).
    avg_loss = total_loss / max(len(dataloader), 1)
    model.train()
    return avg_loss


def cleanup_old_epochs(keep_last=KEEP_LAST_EPOCHS):
    """Delete all but the most recent `keep_last` epoch checkpoint directories."""
    epochs = sorted([p for p in OUTPUT_DIR.glob("epoch*") if p.is_dir()],
                    key=lambda x: int(x.name.replace("epoch", "")))
    for old in epochs[:-keep_last]:
        if old.exists():
            shutil.rmtree(old)
            print(f"Deleted old epoch: {old.name}")


def train():
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = GPTPyTorch().to(device)

    # Resume from the last trained checkpoint if present, otherwise start from
    # the base model; weights_only=True loads tensors only (no pickled objects).
    load_kwargs = {"map_location": device, "weights_only": True}
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming training from last model: {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, **load_kwargs))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model: {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, **load_kwargs))
    else:
        print("No models found — initializing from scratch")

    model.train()

    train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2",
                                split_type='train', val_ratio=VAL_SPLIT_RATIO)
    val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, encoding_name="gpt2",
                              split_type='val', val_ratio=VAL_SPLIT_RATIO)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    total_steps = len(train_dataloader) * EPOCHS
    print("\n=== STARTING LONG-TERM TRAINING ===")
    print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{EPOCHS} ---")
        epoch_loss = 0.0

        with tqdm(train_dataloader, desc=f"Epoch {epoch} [TRAIN]", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs, targets = inputs.to(device), targets.to(device)

                # Forward pass, cross-entropy over flattened logits, backward
                # pass, gradient clipping, optimizer step.
                optimizer.zero_grad()
                logits, _ = model(inputs)
                loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

                loss_val = loss.item()
                epoch_loss += loss_val
                global_step += 1

                # The displayed perplexity clamps loss at 10 to keep the bar readable.
                pbar.set_postfix({
                    "loss": f"{loss_val:.3f}",
                    "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
                    "step": f"{global_step}/{total_steps}"
                })

        avg_train_loss = epoch_loss / len(train_dataloader)
        print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")

        print(" [VALIDATION] Running evaluation...")
        val_loss = evaluate(model, val_dataloader, criterion, device)
        print(f" [VALIDATION] Average loss: {val_loss:.3f} | PPL: {math.exp(val_loss):.1f}")

        # Save this epoch's checkpoint and prune older ones.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)
        torch.save(model.state_dict(), epoch_dir / MODEL_SAVE_NAME)
        print(f"Model saved: {epoch_dir / MODEL_SAVE_NAME}")
        cleanup_old_epochs()

    # Save the final model, back up the previous "last trained" weights, then
    # overwrite LAST_TRAINED_PATH so the next run resumes from this model.
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    torch.save(model.state_dict(), final_dir / MODEL_SAVE_NAME)

    if LAST_TRAINED_PATH.exists():
        backup_path = BACKUP_DIR / f"gpt_last_trained_backup_{int(os.path.getmtime(LAST_TRAINED_PATH))}.pt"
        shutil.copy(LAST_TRAINED_PATH, backup_path)
        print(f"Backup of previous model created → {backup_path.name}")

    shutil.copy(final_dir / MODEL_SAVE_NAME, LAST_TRAINED_PATH)
    print(f"Last trained model saved → {LAST_TRAINED_PATH}")

    print("\nTRAINING COMPLETED! Model is ready:")
    print(f" • For chat/inference: {final_dir / MODEL_SAVE_NAME}")
    print(f" • For continued fine-tuning: {LAST_TRAINED_PATH}")


if __name__ == "__main__":
    # The module-level cleaning step above has already produced DATASET_PATH
    # (or raised if no source text was available), so check the training file.
    if not DATASET_PATH.exists():
        print(f"ERROR: File {DATASET_PATH} not found")
        print("Place your text in datasets/dialogues_text.txt")
    else:
        train()