| import os |
| import random |
| from collections import Counter |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from tqdm import tqdm |
| import glob |
|
|
# Checkpoint and corpus locations.
MODEL_FILE = "AgGPT21.pt"  # saved checkpoint: model weights + stoi/itos vocab maps
DATA_FOLDER = "training_corpora/"  # directory scanned for *.txt training files


# Reproducibility: seed the python and torch RNGs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)


# Model / windowing hyperparameters.
SEQ_LEN = 64  # tokens per training window
STRIDE = 64  # window step; equal to SEQ_LEN -> non-overlapping windows
EMBED_SIZE = 128  # embedding dim (tied with the output layer in WordRNN)
HIDDEN_SIZE = 128  # GRU hidden dim
NUM_LAYERS = 1  # GRU layers
DROPOUT = 0.2  # dropout applied to embeddings and GRU outputs


# Optimization hyperparameters.
BATCH_SIZE = 8
EPOCHS = 6
LR = 2e-3  # AdamW learning rate
WEIGHT_DECAY = 1e-4  # AdamW weight decay
CLIP_NORM = 1.0  # max gradient L2 norm before clipping


# Data and demo-generation limits.
GENERATE_LENGTH = 200  # words generated per demo prompt
DATA_PERCENT = 0.1  # fraction of corpus files actually used
MAX_TOKENS = 1_000_000  # cap on total training tokens
MAX_VOCAB = 30000  # cap on vocabulary size (including <unk>)


# Sampling controls for text generation.
TEMPERATURE = 0.9  # <1 sharpens the distribution, >1 flattens it
TOP_K = 50  # sample only among the 50 most likely words
TOP_P = 0.9  # nucleus mass (used only when the top-k path is skipped)


# Device preference: Apple MPS, then CUDA, else CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
|
|
def build_vocab_and_ids(folder_path, percent=1.0, max_tokens=None, max_vocab=None):
    """Build a word-level vocabulary and encode the corpus as token IDs.

    Reads every ``*.txt`` file in *folder_path*, lowercases the text and
    splits on whitespace. ID 0 is reserved for the ``<unk>`` token.

    Args:
        folder_path: Directory containing the training ``.txt`` files.
        percent: Fraction (0-1] of files to use, taken from the front of the
            alphabetically sorted file list so the subset is deterministic.
        max_tokens: Optional cap on the total number of word tokens kept.
        max_vocab: Optional cap on vocabulary size (including ``<unk>``).

    Returns:
        Tuple ``(vocab, stoi, itos, ids)``: the word list, word->id and
        id->word maps, and the whole corpus encoded as ids.

    Raises:
        FileNotFoundError: If the folder contains no ``.txt`` files.
    """
    # Sort up front: glob order is filesystem-dependent, so both the percent
    # subset and the read order must be fixed for reproducible vocab/ids.
    txt_files = sorted(glob.glob(os.path.join(folder_path, "*.txt")))
    if not txt_files:
        raise FileNotFoundError(f"No .txt files found in {folder_path}")

    total_files = len(txt_files)
    print(f"Found {total_files} training files")

    if percent < 1.0:
        num_files_to_use = max(1, int(total_files * percent))
        txt_files = txt_files[:num_files_to_use]
        print(f"Using {percent*100}% of files: {num_files_to_use}/{total_files} files")

    all_words = []
    for file_path in txt_files:
        print(f"Reading {os.path.basename(file_path)}...")
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read().lower()
        # str.split() with no argument already drops empty strings.
        all_words.extend(text.split())

    print(f"Total words loaded: {len(all_words):,}")

    if max_tokens is not None:
        all_words = all_words[:max_tokens]
        print(f"Truncated to max_tokens: {len(all_words):,} words")

    counts = Counter(all_words)
    if max_vocab is not None:
        # Reserve one slot for <unk>; a literal "<unk>" in the corpus is
        # filtered out so it never appears twice in the vocabulary.
        keep = max(1, max_vocab - 1)
        common = [w for w, _ in counts.most_common(keep) if w != "<unk>"]
        vocab = ["<unk>"] + common
    else:
        vocab = ["<unk>"] + [w for w in counts if w != "<unk>"]

    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}
    # Out-of-vocabulary words map to <unk> (id 0).
    ids = [stoi.get(w, 0) for w in all_words]

    print(f"Vocabulary size: {len(vocab):,}")
    return vocab, stoi, itos, ids
|
|
class WordDataset(Dataset):
    """Fixed-length next-word-prediction windows over a token-id sequence.

    Each item is a pair ``(x, y)`` of LongTensors of length ``seq_len``
    where ``y`` is ``x`` shifted one position to the right (teacher
    forcing targets).
    """

    def __init__(self, ids, seq_len, stride=None):
        self.ids = ids
        self.seq_len = seq_len
        # A falsy stride (None/0) falls back to non-overlapping windows.
        self.stride = stride if stride else seq_len
        # Count windows that still leave room for the +1 target shift;
        # clamps to zero when the sequence is too short for even one.
        usable = len(self.ids) - self.seq_len - 1
        self.n = max(0, usable // self.stride + 1)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        lo = idx * self.stride
        hi = lo + self.seq_len
        window = self.ids[lo:hi]
        target = self.ids[lo + 1:hi + 1]
        return (
            torch.tensor(window, dtype=torch.long),
            torch.tensor(target, dtype=torch.long),
        )
|
|
class WordRNN(nn.Module):
    """GRU language model with tied input/output embeddings.

    The output projection shares its weight matrix with the embedding
    table; when ``hidden_size`` differs from ``embed_size`` a bias-free
    linear bridge first maps GRU outputs into the embedding dimension so
    the tied matrix shapes line up.
    """

    def __init__(self, vocab_size, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout=DROPOUT):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.drop = nn.Dropout(dropout)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        # Bridge is only needed when GRU output dim != embedding dim.
        if hidden_size != embed_size:
            self.proj = nn.Linear(hidden_size, embed_size, bias=False)
        else:
            self.proj = None
        fc_in = hidden_size if self.proj is None else embed_size
        self.fc = nn.Linear(fc_in, vocab_size, bias=False)
        self.fc.weight = self.embed.weight  # weight tying

    def forward(self, x, hidden=None):
        """Return (logits, hidden) for a batch of token ids shaped (B, T)."""
        states, new_hidden = self.gru(self.drop(self.embed(x)), hidden)
        states = self.drop(states)
        if self.proj is not None:
            states = self.proj(states)
        return self.fc(states), new_hidden
|
|
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
| def evaluate(model, dataloader, device, use_amp): |
| model.eval() |
| criterion = nn.CrossEntropyLoss(ignore_index=0) |
| total_loss = 0.0 |
| with torch.no_grad(): |
| for x, y in dataloader: |
| x = x.to(device) |
| y = y.to(device) |
| with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp): |
| logits, _ = model(x) |
| loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1)) |
| total_loss += loss.item() |
| return total_loss / max(1, len(dataloader)) |
|
|
def train(model, train_loader, val_loader, epochs, lr, device, weight_decay, clip_norm, stoi, itos):
    """Train *model* with AdamW and early stopping on validation loss.

    After every epoch whose validation loss improves, a checkpoint
    (model weights + vocab maps) is written to MODEL_FILE; that best
    checkpoint is reloaded before returning, so the model returned is
    the best-validation one, not necessarily the last epoch's.

    Args:
        model: Network to optimize; moved to *device* in place.
        train_loader: DataLoader yielding (x, y) id batches for training.
        val_loader: DataLoader used by evaluate() each epoch.
        epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        device: torch.device to train on.
        weight_decay: AdamW weight decay.
        clip_norm: Max gradient L2 norm for clipping.
        stoi: Word->id map, stored inside the checkpoint.
        itos: Id->word map, stored inside the checkpoint.

    Returns:
        The model with the best-validation-loss weights loaded.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # ignore_index=0 excludes <unk> targets from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    # Mixed precision only on accelerators. NOTE(review): fp16 autocast on
    # CUDA without a GradScaler can underflow gradients — confirm acceptable.
    use_amp = device.type in {"mps", "cuda"}
    best_val = float("inf")
    patience = 2  # epochs without val improvement before stopping
    epochs_no_improve = 0
    print(f"Train batches per epoch: {len(train_loader)} | Val batches: {len(val_loader)}")
    epoch_bar = tqdm(range(epochs), desc="Epochs")
    for epoch in epoch_bar:
        model.train()
        total_loss = 0.0
        batch_bar = tqdm(train_loader, desc=f"Train {epoch+1}/{epochs}", leave=False)
        for x, y in batch_bar:
            x = x.to(device)
            y = y.to(device)
            opt.zero_grad()
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, _ = model(x)
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            # backward/clip/step run outside the autocast region.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()
            total_loss += loss.item()
        batch_bar.close()
        train_loss = total_loss / max(1, len(train_loader))
        val_loss = evaluate(model, val_loader, device, use_amp)
        epoch_bar.set_postfix(train=f"{train_loss:.4f}", val=f"{val_loss:.4f}")
        # Require a minimal improvement (1e-4) to count as progress.
        if val_loss < best_val - 1e-4:
            best_val = val_loss
            epochs_no_improve = 0
            torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping.")
                break
    # Epoch 1 always improves on +inf, so the checkpoint file exists here.
    ckpt = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    return model
|
|
| def _sample_next_id(probs_1d, top_k=None, top_p=None, temperature=1.0, forbid_ids=None): |
| probs = probs_1d.clone() |
| if forbid_ids: |
| for i in forbid_ids: |
| if 0 <= i < probs.numel(): |
| probs[i] = 0 |
| if temperature != 1.0: |
| logits = torch.log(probs + 1e-9) / temperature |
| probs = torch.softmax(logits, dim=-1) |
| if probs.sum() <= 0: |
| probs = torch.ones_like(probs) |
| if forbid_ids: |
| for i in forbid_ids: |
| if 0 <= i < probs.numel(): |
| probs[i] = 0 |
| probs = probs / probs.sum() |
| if top_k is not None and top_k > 0: |
| k = min(top_k, probs.size(-1)) |
| values, indices = torch.topk(probs, k) |
| values = values / values.sum() |
| idx = indices[torch.multinomial(values, 1)] |
| return idx.item() |
| if top_p is not None and 0 < top_p < 1.0: |
| sorted_probs, sorted_indices = torch.sort(probs, descending=True) |
| cumulative = torch.cumsum(sorted_probs, dim=-1) |
| keep_mask = cumulative <= top_p |
| keep = int(keep_mask.nonzero()[-1].item()) + 1 if keep_mask.any() else 1 |
| sorted_probs = sorted_probs[:keep] |
| sorted_indices = sorted_indices[:keep] |
| sorted_probs = sorted_probs / sorted_probs.sum() |
| idx_pos = torch.multinomial(sorted_probs, 1) |
| return sorted_indices[idx_pos].item() |
| probs = probs / probs.sum() |
| return torch.multinomial(probs, 1).item() |
|
|
def generate_text(model, stoi, itos, prompt, length=GENERATE_LENGTH, seq_len=SEQ_LEN, device=DEVICE, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P):
    """Autoregressively generate *length* words continuing *prompt*.

    The prompt is lowercased and whitespace-tokenized; unknown words map
    to id 0 (<unk>). After the first forward pass only the newly sampled
    token is fed back, with the GRU hidden state carried across steps.

    Args:
        model: Trained WordRNN (moved to *device* and set to eval mode).
        stoi: Word->id map.
        itos: Id->word map.
        prompt: Seed text to continue.
        length: Number of words to generate.
        seq_len: Context window size used to seed the hidden state.
        device: torch.device to run on.
        temperature: Sampling temperature passed to _sample_next_id.
        top_k: Top-k cutoff passed to _sample_next_id.
        top_p: Nucleus cutoff passed to _sample_next_id.

    Returns:
        The prompt words plus generated words joined by single spaces.
    """
    model.to(device)
    model.eval()
    words = prompt.lower().split()
    ids = [stoi.get(w, 0) for w in words]
    # Left-pad short prompts with id 0 so the context is exactly seq_len.
    context = ids[-seq_len:] if len(ids) >= seq_len else [0] * (seq_len - len(ids)) + ids
    input_ids = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device)
    hidden = None
    generated = words.copy()
    use_amp = device.type in {"mps", "cuda"}
    with torch.no_grad():
        gen_bar = tqdm(range(length), desc="Generating", leave=False)
        for _ in gen_bar:
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, hidden = model(input_ids, hidden)
            # Sample from the distribution at the last position; never emit <unk>.
            probs = torch.softmax(logits[:, -1, :], dim=-1).squeeze(0)
            next_id = _sample_next_id(probs, top_k=top_k, top_p=top_p, temperature=temperature, forbid_ids=[0])
            next_word = itos.get(next_id, "<unk>")
            generated.append(next_word)
            # Feed only the new token; hidden state carries the context.
            input_ids = torch.tensor([[next_id]], dtype=torch.long).to(device)
    return " ".join(generated)
|
|
if __name__ == "__main__":
    # Resume from an existing checkpoint if present; otherwise train anew.
    if os.path.exists(MODEL_FILE):
        # NOTE(review): the checkpoint stores plain dicts (stoi/itos), so
        # torch.load must be allowed to unpickle them — confirm behavior on
        # torch versions where weights_only defaults to True.
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]
        itos = ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        print(f"Loaded model {MODEL_FILE} | device={DEVICE} | params={count_parameters(model):,}")
    else:
        if not os.path.exists(DATA_FOLDER):
            raise FileNotFoundError(f"Training folder not found: {DATA_FOLDER}")
        vocab, stoi, itos, ids = build_vocab_and_ids(DATA_FOLDER, percent=DATA_PERCENT, max_tokens=MAX_TOKENS, max_vocab=MAX_VOCAB)
        print(f"Vocab size: {len(vocab):,} | Tokens used: {len(ids):,} | device={DEVICE}")
        # Hold out the corpus tail for validation: 5% or at least 5 windows.
        # NOTE(review): assumes len(ids) > val_tokens; a tiny corpus would
        # leave train_ids empty — confirm corpus size is always sufficient.
        val_tokens = max(SEQ_LEN * 5, int(0.05 * len(ids)))
        train_ids = ids[:-val_tokens]
        val_ids = ids[-val_tokens:]
        train_dataset = WordDataset(train_ids, SEQ_LEN, stride=STRIDE)
        val_dataset = WordDataset(val_ids, SEQ_LEN, stride=STRIDE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
        model = WordRNN(len(vocab))
        print(f"Model params: {count_parameters(model):,}")
        # train() already saves the best checkpoint; this final save persists
        # the (reloaded best) weights again alongside the vocab maps.
        model = train(model, train_loader, val_loader, EPOCHS, LR, DEVICE, WEIGHT_DECAY, CLIP_NORM, stoi, itos)
        torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        print(f"Saved {MODEL_FILE}")

    # Quick qualitative demo on a few fixed prompts.
    print("\n=== AgGPT-21 Demo ===")
    prompts = ["hello world", "how are you", "once upon a time", "tell me about"]
    for p in prompts:
        print(f"\nPrompt: {p}")
        print(f"Generated: {generate_text(model, stoi, itos, p)}")
    print("\nTraining complete! Use chat.py for interactive conversation.")