# Source: NousAPI / best.py (uploaded by FaiziRBLX with huggingface_hub, verified commit 17edf27)
"""
Production-Grade Indonesian Conversational Language Model
Trained from scratch with Chain-of-Thought reasoning capability
Architecture: Decoder-only transformer with MQA/GQA, RoPE, SwiGLU, RMSNorm
Target: 15M-30M parameters, optimized for Google Colab Free tier
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import json
import math
import random
import numpy as np
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass
import warnings
import argparse
import os
# Silence every Python warning for the whole process (this also hides
# deprecation notices emitted by torch / transformers).
warnings.simplefilter('ignore')
# ============================================================================
# CONFIGURATION
# ============================================================================
@dataclass
class ModelConfig:
    """Model architecture configuration.

    ``__post_init__`` validates the attention layout:
      * head counts must be positive,
      * ``hidden_size`` must split evenly across query heads,
      * query heads must split evenly across key/value heads (GQA groups),
      * the per-head dimension must be even, because rotary embeddings split
        each head in half (an odd head size produces a cos/sin table one
        column wider than the head and fails at forward time).

    Raises:
        ValueError: if any of the checks above fails.
    """
    vocab_size: int = 30000
    hidden_size: int = 384
    num_layers: int = 12
    num_attention_heads: int = 6
    num_key_value_heads: int = 2  # GQA: 2 KV heads, MQA: 1 KV head
    intermediate_size: int = 1024  # inner width of the SwiGLU feed-forward
    max_position_embeddings: int = 2048  # initial length of the RoPE cos/sin cache
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0  # RoPE frequency base
    attention_dropout: float = 0.1
    residual_dropout: float = 0.1
    initializer_range: float = 0.02  # std of the normal weight initialisation
    use_cache: bool = False
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2
    tie_word_embeddings: bool = True  # reuse the embedding matrix as the output projection

    def __post_init__(self):
        # Explicit raises (not asserts) so validation survives ``python -O``.
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) and "
                f"num_key_value_heads ({self.num_key_value_heads}) must be positive"
            )
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        head_dim = self.hidden_size // self.num_attention_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"head dimension hidden_size / num_attention_heads = {head_dim} "
                f"must be even for rotary embeddings"
            )
@dataclass
class TrainingConfig:
    """Training hyperparameters.

    ``curriculum_stages`` defaults to ``[256, 512, 1024]``. The default is filled
    in ``__post_init__`` so every instance gets its own fresh list.
    """
    dataset_path: str = "indonesian_cot_dataset.jsonl"
    output_dir: str = "./indonesian_llm_checkpoints"
    # Training
    num_epochs: int = 3
    batch_size: int = 4
    gradient_accumulation_steps: int = 12
    max_seq_length: int = 1024
    # Optimization
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    # Scheduler
    warmup_steps: int = 100
    lr_scheduler_type: str = "cosine"  # "cosine"; any other value selects linear decay
    # Regularization
    dropout: float = 0.1
    # Mixed precision
    use_fp16: bool = True
    # Reproducibility
    seed: int = 42
    # Logging
    logging_steps: int = 10
    eval_steps: int = 100
    save_steps: int = 500
    # Curriculum learning: per-stage max token lengths (None -> [256, 512, 1024]).
    curriculum_stages: Optional[List[int]] = None
    # Skip the first N curriculum stages so we don't re-train on tiny seqs.
    skip_curriculum_stages: int = 2
    # Patience (in eval periods) before the plateau LR reduction fires.
    plateau_patience: int = 3
    # Factor to multiply LR by when plateau is detected.
    plateau_factor: float = 0.5
    # Minimum relative improvement in perplexity to count as "not stalled".
    plateau_min_delta: float = 0.02
    # EWC — set > 0 to enable the anti-forgetting penalty during finetuning.
    ewc_lambda: float = 0.0
    ewc_samples: int = 2000  # samples used to estimate Fisher Information

    def __post_init__(self):
        if self.curriculum_stages is None:
            self.curriculum_stages = [256, 512, 1024]
# ============================================================================
# ROTARY POSITIONAL EMBEDDINGS (RoPE)
# ============================================================================
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings: cached cos/sin tables, one row per position.

    The tables are ``(seq_len, dim)`` and are rebuilt (grown) on demand when a
    longer sequence is requested than the current cache covers.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angle columns so the table width matches the full head dimension.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``."""
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        target_dtype = x.dtype
        cos = self.cos_cached[:seq_len].to(dtype=target_dtype)
        sin = self.sin_cached[:seq_len].to(dtype=target_dtype)
        return (cos, sin)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension into halves ``(a, b)`` and return ``(-b, a)``."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by the position-dependent angles in ``cos``/``sin``.

    Computes ``t * cos + rotate_half(t) * sin`` for ``t`` in ``(q, k)``; ``cos``
    and ``sin`` must broadcast against both tensors.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
# ============================================================================
# RMS NORMALIZATION
# ============================================================================
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale (no centering, no bias).

    The statistic is computed in float32; the normalised activations are cast
    back to the input dtype before the scale is applied.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
# ============================================================================
# GROUPED-QUERY ATTENTION
# ============================================================================
class GroupedQueryAttention(nn.Module):
    """Multi-head self-attention with grouped key/value heads (GQA; MQA when 1 KV head).

    Each key/value head is shared by ``num_heads // num_key_value_heads`` query
    heads. Rotary embeddings are applied to queries and keys; an optional
    additive ``attention_mask`` is added to the scaled scores before softmax.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        assert self.hidden_size % self.num_heads == 0

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        # Creation order (q, k, v, o) is kept so random initialisation is unchanged.
        self.q_proj = nn.Linear(self.hidden_size, q_width, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, self.hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _split_heads(self, tensor, n_heads):
        """(batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)."""
        batch, seq_len, _ = tensor.shape
        return tensor.view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        bsz, q_len, _ = hidden_states.size()

        queries = self._split_heads(self.q_proj(hidden_states), self.num_heads)
        keys = self._split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        values = self._split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        cos, sin = self.rotary_emb(values, seq_len=q_len)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        groups = self.num_key_value_groups
        if groups > 1:
            # Expand each KV head so it lines up with its group of query heads.
            keys = keys.repeat_interleave(groups, dim=1)
            values = values.repeat_interleave(groups, dim=1)

        scores = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask
        # Softmax in float32 for stability, then back to the query dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)
        probs = self.attention_dropout(probs)

        context = torch.matmul(probs, values)
        context = context.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(context)
# ============================================================================
# SWIGLU FEEDFORWARD
# ============================================================================
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        self.hidden_size = d_model
        self.intermediate_size = d_ff
        # Creation order (gate, up, down) is kept so random initialisation is unchanged.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
# ============================================================================
# DECODER LAYER
# ============================================================================
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: x + drop(attn(norm(x))), then x + drop(mlp(norm(x)))."""

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SwiGLUMLP(config)
        self.residual_dropout = nn.Dropout(config.residual_dropout)

    def forward(self, hidden_states, attention_mask=None):
        # Attention sub-layer with residual connection.
        attn_out = self.self_attn(self.input_layernorm(hidden_states), attention_mask=attention_mask)
        hidden_states = hidden_states + self.residual_dropout(attn_out)
        # Feed-forward sub-layer with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + self.residual_dropout(mlp_out)
# ============================================================================
# MAIN MODEL
# ============================================================================
class IndonesianLLM(nn.Module):
    """Decoder-only causal language model.

    Pipeline: token embeddings -> ``config.num_layers`` DecoderLayers -> final
    RMSNorm -> vocabulary logits. Logits come from the embedding matrix when
    ``config.tie_word_embeddings`` is set, otherwise from a separate ``lm_head``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        self.layers = nn.ModuleList([DecoderLayer(config, idx) for idx in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if config.tie_word_embeddings:
            # forward() projects with embed_tokens.weight instead of a separate head.
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights; zero biases and the padding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.embed_tokens

    def _prepare_attention_mask(self, attention_mask, input_shape, dtype, device=None):
        """Build the additive ``(batch, 1, seq, seq)`` attention mask.

        Entries for future positions (causal) and for padding keys
        (``attention_mask == 0``) are ``torch.finfo(dtype).min``; all others are 0.

        Args:
            attention_mask: ``(batch, seq)`` 1/0 padding mask, or None for no padding.
            input_shape: ``(batch_size, seq_length)``.
            dtype: floating dtype whose minimum value marks blocked positions.
            device: device for the mask. Defaults to ``attention_mask.device``, or
                to the embedding weights' device when no mask is given (previously
                a None mask crashed here before the None check was reached).
        """
        batch_size, seq_length = input_shape
        if device is None:
            device = attention_mask.device if attention_mask is not None else self.embed_tokens.weight.device
        causal_mask = torch.triu(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=device),
            diagonal=1,
        )
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, seq_length, seq_length)
        if attention_mask is not None:
            expanded_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length)
            causal_mask = causal_mask | ~expanded_mask.bool()
        return torch.where(causal_mask, torch.finfo(dtype).min, 0.0)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Return ``{"loss": ..., "logits": ...}`` for a batch of token ids.

        ``loss`` is the next-token cross-entropy (logits at position t scored
        against ``labels[t + 1]``, with ``pad_token_id`` targets ignored) when
        ``labels`` is given, otherwise None.
        """
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        hidden_states = self.embed_tokens(input_ids)
        attention_mask = self._prepare_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states.dtype, device=input_ids.device
        )
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask)
        hidden_states = self.norm(hidden_states)
        if self.lm_head is not None:
            logits = self.lm_head(hidden_states)
        else:
            logits = F.linear(hidden_states, self.embed_tokens.weight)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
        return {"loss": loss, "logits": logits}

    def count_parameters(self):
        """Number of trainable parameters (a tied embedding/output matrix is counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ============================================================================
# DATASET
# ============================================================================
class IndonesianCoTDataset(Dataset):
    """JSONL dataset of ``{"input", "cot", "output"}`` records for causal-LM training.

    Each non-blank line of ``file_path`` must be a JSON object whose ``input``,
    ``cot`` and ``output`` fields are non-empty strings. Other lines are skipped
    with a warning and counted in ``skipped_count``.

    ``__getitem__`` renders one sample as text and tokenizes it without padding,
    truncated to ``max_length``. When ``use_cot`` is set, the chain-of-thought is
    included, wrapped in ``cot_token`` / ``end_cot_token``, with probability
    ``cot_ratio``.

    Raises:
        ValueError: if no valid sample is found in the file.
    """

    _REQUIRED_FIELDS = ('input', 'cot', 'output')

    def __init__(
        self,
        file_path: str,
        tokenizer,
        max_length: int = 1024,
        cot_token: str = "<cot>",
        end_cot_token: str = "</cot>",
        use_cot: bool = True,
        cot_ratio: float = 0.7
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.cot_token = cot_token
        self.end_cot_token = end_cot_token
        self.use_cot = use_cot
        self.cot_ratio = cot_ratio
        self.samples = []
        self.skipped_count = 0
        self._load_data(file_path)

    def _load_data(self, file_path: str):
        """Read ``file_path`` line by line, keeping only well-formed records."""
        fields = self._REQUIRED_FIELDS
        print(f"Loading dataset from {file_path}...")
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} is not valid JSON ({e}), skipping...")
                    continue
                except Exception as e:  # e.g. RecursionError on pathologically nested JSON
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} caused error ({e}), skipping...")
                    continue
                # A JSON list or string would otherwise pass the `key in data`
                # membership test below and fail later with a TypeError.
                if not isinstance(data, dict):
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} is not a JSON object, skipping...")
                    continue
                if not all(key in data for key in fields):
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} missing required fields, skipping...")
                    continue
                if not all(isinstance(data[key], str) for key in fields):
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} has invalid data types, skipping...")
                    continue
                if not all(data[key].strip() for key in fields):
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} has empty fields, skipping...")
                    continue
                self.samples.append(data)
        print(f"Loaded {len(self.samples)} valid samples")
        print(f"Skipped {self.skipped_count} malformed rows")
        if len(self.samples) == 0:
            raise ValueError("No valid samples loaded from dataset!")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'length', 'cot_length'}`` for sample ``idx``.

        ``cot_length`` is the token count of the sample's CoT field whenever
        ``use_cot`` is set, regardless of whether this draw included the CoT.
        """
        sample = self.samples[idx]
        if self.use_cot and random.random() < self.cot_ratio:
            text = f"{sample['input']} {self.cot_token} {sample['cot']} {self.end_cot_token} {sample['output']}"
        else:
            text = f"{sample['input']} {sample['output']}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors=None
        )
        input_ids = encoding['input_ids']
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'length': len(input_ids),
            'cot_length': len(self.tokenizer.encode(sample['cot'], add_special_tokens=False)) if self.use_cot else 0
        }
def collate_fn_with_packing(batch, pad_token_id=0):
    """Right-pad a list of tokenized samples into one batch.

    Samples are ordered longest-first. Each ``input_ids`` is padded with
    ``pad_token_id`` to the longest length in the batch. ``attention_mask`` is
    1 on real tokens and 0 on padding, and ``labels`` is a copy of the padded
    ids (padding is ignored by the model's loss via its ``ignore_index``).
    """
    ordered = sorted(batch, key=lambda item: item['length'], reverse=True)
    target_len = max(item['length'] for item in ordered)
    positions = torch.arange(target_len)

    padded_ids = []
    masks = []
    for item in ordered:
        n_real = item['length']
        padded_ids.append(F.pad(item['input_ids'], (0, target_len - n_real), value=pad_token_id))
        masks.append((positions < n_real).long())

    return {
        'input_ids': torch.stack(padded_ids),
        'attention_mask': torch.stack(masks),
        'labels': torch.stack(padded_ids),
    }
# ============================================================================
# CURRICULUM LEARNING
# ============================================================================
def create_curriculum_datasets(dataset, stages=(256, 512, 1024), use_simple=False, skip_stages=0):
    """
    Build per-stage datasets for curriculum training.

    Args:
        dataset: base IndonesianCoTDataset whose tokenizer and samples are shared.
        stages: per-stage max token lengths. Used only when ``use_simple=True``;
            the reasoning curriculum has fixed 384 / 512 / 1024 stages.
        use_simple: if True, stage ``i`` keeps the samples whose full CoT
            rendering is at most ``stages[i]`` tokens, with ``dataset.cot_ratio``.
            Otherwise, build the 3-stage reasoning curriculum
            (short no-CoT -> 50% CoT -> 100% CoT).
        skip_stages: skip the first N stages (continue-train only). With
            use_simple=True and skip_stages == len(stages)-1, only the last
            (longest) stage is trained, which is what you want for continuation.

    Returns:
        List of stage datasets, with the first ``skip_stages`` removed.
    """
    datasets = []
    if use_simple:
        # Tokenize each sample's full CoT rendering once; every stage then only
        # compares these cached lengths instead of re-tokenizing the dataset.
        full_lengths = [
            len(dataset.tokenizer.encode(
                f"{s['input']} {dataset.cot_token} {s['cot']} {dataset.end_cot_token} {s['output']}"
            ))
            for s in dataset.samples
        ]
        for i, max_len in enumerate(stages):
            filtered_samples = [
                s for s, n_tokens in zip(dataset.samples, full_lengths) if n_tokens <= max_len
            ]
            stage_dataset = _build_stage_dataset(dataset, filtered_samples, max_len, dataset.cot_ratio)
            datasets.append(stage_dataset)
            skipped = i < skip_stages
            print(f"{'[SKIP] ' if skipped else ''}Curriculum stage {max_len}: {len(filtered_samples)} samples")
    else:
        print("\n" + "="*80)
        print("3-STAGE REASONING CURRICULUM")
        if skip_stages > 0:
            print(f" (Skipping first {skip_stages} stage(s) — continue-train mode)")
        print("="*80)
        stage_configs = [
            {
                'name': 'Stage 1: Basic Q&A (short, no reasoning)',
                'max_len': 384,
                'cot_ratio': 0.0,
                'filter': lambda s: len(dataset.tokenizer.encode(
                    f"{s['input']} {s['output']}"
                )) <= 384
            },
            {
                'name': 'Stage 2: Learning Reasoning (medium, 50% CoT)',
                'max_len': 512,
                'cot_ratio': 0.5,
                'filter': lambda s: True
            },
            {
                'name': 'Stage 3: Full Reasoning (all, 100% CoT)',
                'max_len': 1024,
                'cot_ratio': 1.0,
                'filter': lambda s: True
            }
        ]
        for idx, stage_config in enumerate(stage_configs):
            filtered_samples = [s for s in dataset.samples if stage_config['filter'](s)]
            stage_dataset = _build_stage_dataset(
                dataset, filtered_samples,
                stage_config['max_len'], stage_config['cot_ratio']
            )
            datasets.append(stage_dataset)
            skipped = idx < skip_stages
            prefix = " [SKIP] " if skipped else " "
            print(f"{prefix}{stage_config['name']}")
            print(f" {'(skipped)' if skipped else ''} Samples: {len(filtered_samples)}")
            print(f" {'(skipped)' if skipped else ''} Max length: {stage_config['max_len']}")
            print(f" {'(skipped)' if skipped else ''} CoT ratio: {stage_config['cot_ratio']*100:.0f}%")
        print("="*80 + "\n")
    if skip_stages > 0:
        if skip_stages >= len(datasets):
            # Surface this loudly: the caller would otherwise train on nothing.
            print(f"Warning: skip_stages={skip_stages} skips all {len(datasets)} "
                  f"curriculum stage(s); no training stage remains.")
        datasets = datasets[skip_stages:]
    return datasets
def _build_stage_dataset(base_dataset, samples, max_len, cot_ratio):
    """Create a stage view of ``base_dataset`` without re-reading the data file.

    The stage shares the base dataset's tokenizer, CoT markers and ``use_cot``
    flag, but uses its own sample list, ``max_length`` and ``cot_ratio``.
    ``__init__`` is bypassed on purpose so no file I/O happens.
    """
    stage = IndonesianCoTDataset.__new__(IndonesianCoTDataset)
    for attr in ('tokenizer', 'cot_token', 'end_cot_token', 'use_cot'):
        setattr(stage, attr, getattr(base_dataset, attr))
    stage.max_length = max_len
    stage.samples = samples
    stage.skipped_count = 0
    stage.cot_ratio = cot_ratio
    return stage
# ============================================================================
# LEARNING-RATE SCHEDULERS
# ============================================================================
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """LambdaLR with linear warmup 0 -> 1 over ``num_warmup_steps`` steps.

    After warmup the multiplier decays linearly to 0 at ``num_training_steps``
    and is clamped at 0 afterwards.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _multiplier(step):
        if step >= num_warmup_steps:
            remaining = float(num_training_steps - step)
            return max(0.0, remaining / decay_span)
        return float(step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _multiplier)
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """LambdaLR with linear warmup 0 -> 1, then a half-cosine decay 1 -> 0.

    The cosine phase spans ``num_training_steps - num_warmup_steps`` steps;
    the multiplier is clamped at 0 from below.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _multiplier(step):
        if step >= num_warmup_steps:
            progress = float(step - num_warmup_steps) / decay_span
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        return float(step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _multiplier)
def get_continue_schedule(optimizer, num_training_steps: int, min_fraction: float = 0.1):
    """
    LR schedule for --continue-train without saved optimizer state.

    No warmup: the multiplier starts at 1.0 (the full target LR) and follows a
    half cosine down to ``min_fraction`` at ``num_training_steps``. It then
    holds that floor.

    Rationale (original author): a freshly built AdamW has zero moment
    estimates, and after bias correction its first update is roughly
    LR·sign(g) per parameter. So this schedule skips warmup rather than
    delaying training.
    NOTE(review): a warmup would still shrink that first step, since it scales
    the LR itself; revisit if continue-training is unstable early on.
    """
    total = float(max(1, num_training_steps))
    span = 1.0 - min_fraction

    def _multiplier(step):
        progress = min(float(step) / total, 1.0)
        return min_fraction + span * (0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _multiplier)
class PlateauLRGuard:
    """
    Wrap an LR scheduler and apply an extra multiplicative LR penalty when
    perplexity has not improved (by a relative ``min_delta``) for ``patience``
    consecutive checks.

    Usage:
        guard = PlateauLRGuard(scheduler, patience=3, factor=0.5, min_delta=0.02)
        guard.step(current_perplexity)  # call after every eval period

    The penalty is applied both to the optimizer's current LR and to the
    scheduler's ``base_lrs``. ``LambdaLR`` recomputes ``lr = base_lr * lambda(step)``
    on every ``scheduler.step()``, so changing only ``param_groups['lr']`` would
    be overwritten at the next step.
    """

    def __init__(self, scheduler, patience=3, factor=0.5, min_delta=0.02):
        self.scheduler = scheduler
        self.patience = patience
        self.factor = factor
        self.min_delta = min_delta
        self._best = float('inf')
        self._no_improve = 0
        self._penalty = 1.0  # cumulative multiplicative penalty

    def step(self, perplexity: float):
        """Record one evaluation. Returns True if an LR penalty was applied."""
        if not math.isfinite(self._best):
            # First (or first usable) measurement: (inf - p) / inf would be NaN,
            # which never compares as an improvement, so seed the baseline here.
            improved = True
        else:
            relative_improvement = (self._best - perplexity) / max(self._best, 1e-8)
            improved = relative_improvement > self.min_delta

        if improved:
            self._best = perplexity
            self._no_improve = 0
            return False

        self._no_improve += 1
        if self._no_improve < self.patience:
            return False

        self._penalty *= self.factor
        self._no_improve = 0
        new_lr = self.scheduler.get_last_lr()[0] * self.factor
        print(f"\n[PlateauLRGuard] No improvement for {self.patience} checks. "
              f"Reducing LR by {self.factor:.2f}× → {new_lr:.2e}")
        # Scale the scheduler's base LRs so the penalty survives later scheduler.step() calls.
        if hasattr(self.scheduler, 'base_lrs'):
            self.scheduler.base_lrs = [lr * self.factor for lr in self.scheduler.base_lrs]
        # Also scale the live LR so the very next optimizer step already uses it.
        for pg in self.scheduler.optimizer.param_groups:
            pg['lr'] *= self.factor
        return True

    def get_penalty(self):
        """Cumulative multiplicative LR penalty applied so far (1.0 = none)."""
        return self._penalty
# ============================================================================
# TRAINING UTILITIES
# ============================================================================
def set_seed(seed: int):
    """Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs and force deterministic cuDNN."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# ============================================================================
# ELASTIC WEIGHT CONSOLIDATION (EWC)
# ============================================================================
class EWC:
    """
    Elastic Weight Consolidation — penalises drift of parameters that mattered for the old data.

    How it works:
        At construction we snapshot every trainable parameter (theta*). We also
        estimate a diagonal Fisher Information value F_i per parameter from
        squared gradients of the model's loss on ``dataloader``. During
        finetuning, ``penalty`` returns

            0.5 * sum_i [ F_i * (theta_i - theta_i*)^2 ]

        and the caller scales it by ``ewc_lambda`` and adds it to the task loss.
        Parameters with high F_i are pulled back towards their old values;
        parameters with low F_i can move freely to learn the new task.

    Note: each Fisher term is the squared gradient of a batch-mean loss. These
    are summed over batches and divided by the number of samples seen, so the
    scale depends on the batch size; ``ewc_lambda`` absorbs that constant.

    Usage:
        ewc = EWC(model, old_dataloader, device, n_samples=2000)
        # then inside the training loop:
        loss = task_loss + ewc_lambda * ewc.penalty(model)
    """

    def __init__(self, model, dataloader, device, n_samples: int = 2000):
        self.device = device
        self.n_samples = n_samples
        # Snapshot of the old (pre-finetuning) trainable parameters.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(model, dataloader)

    def _compute_fisher(self, model, dataloader):
        """Estimate the diagonal Fisher from at least ``n_samples`` samples (whole batches).

        Dropout is disabled during estimation. The model's previous train/eval
        mode is restored and its gradients are cleared afterwards, so no stale
        ``.grad`` tensors leak into the caller's optimizer.
        """
        fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
        was_training = model.training
        model.eval()
        seen = 0
        try:
            for batch in dataloader:
                if seen >= self.n_samples:
                    break
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)
                model.zero_grad()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                outputs["loss"].backward()
                for n, p in model.named_parameters():
                    if p.requires_grad and p.grad is not None:
                        fisher[n] += p.grad.detach().pow(2)
                seen += input_ids.size(0)
        finally:
            model.zero_grad(set_to_none=True)
            model.train(was_training)
        # Normalise by number of samples
        for n in fisher:
            fisher[n] /= max(1, seen)
        return fisher

    def penalty(self, model) -> torch.Tensor:
        """Return the (unscaled) EWC penalty ``0.5 * sum F * (theta - theta*)^2``."""
        loss = torch.tensor(0.0, device=self.device)
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.fisher:
                loss = loss + (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return loss * 0.5
# ============================================================================
# TRAINING LOOP
# ============================================================================
def train_model(
    model: IndonesianLLM,
    train_dataset: IndonesianCoTDataset,
    config: TrainingConfig,
    device: torch.device,
    use_simple_curriculum: bool = False,
    is_continue: bool = False,
    skip_curriculum_stages: int = 0,
    ewc: "EWC | None" = None,
):
    """Main training loop: run every curriculum stage for ``config.num_epochs`` epochs.

    Each stage trains with AdamW and gradient accumulation:
    ``config.gradient_accumulation_steps`` micro-batches per update, plus one
    update for a trailing partial window at the end of each epoch, so no
    micro-batch's gradients are dropped. Gradients are clipped to
    ``config.max_grad_norm``. When ``config.use_fp16`` is set and CUDA is
    available, training uses fp16 autocast with a GradScaler.

    Args:
        model: model to train (moved to ``device`` and trained in place).
        train_dataset: base dataset, split into stages by ``create_curriculum_datasets``.
        config: training hyperparameters.
        device: device to train on.
        use_simple_curriculum: use the length-bucket curriculum instead of the
            3-stage reasoning curriculum.
        is_continue: continue-train mode. Uses a flat cosine decay from the
            target LR plus a ``PlateauLRGuard`` driven by per-epoch training
            perplexity.
        skip_curriculum_stages: number of leading curriculum stages to skip.
        ewc: optional EWC regulariser; its penalty, scaled by ``config.ewc_lambda``,
            is added to the task loss.

    Returns:
        The trained ``model`` (same instance).
    """
    print("\n" + "="*80)
    print("TRAINING CONFIGURATION" + (" [CONTINUE MODE]" if is_continue else ""))
    print("="*80)
    print(f"Model parameters: {model.count_parameters():,}")
    print(f"Dataset size: {len(train_dataset)}")
    print(f"Batch size: {config.batch_size}")
    print(f"Gradient accumulation steps: {config.gradient_accumulation_steps}")
    print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Max sequence length: {config.max_seq_length}")
    print(f"Number of epochs: {config.num_epochs}")
    print(f"Mixed precision: {config.use_fp16}")
    if is_continue:
        print(f"Skipping curriculum stages: {skip_curriculum_stages}")
        print(f"Plateau patience: {config.plateau_patience}")
        print(f"Plateau LR factor: {config.plateau_factor}")
    if ewc is not None:
        print(f"EWC lambda: {config.ewc_lambda} (anti-forgetting active)")
    print("="*80 + "\n")

    model.to(device)
    model.train()
    curriculum_datasets = create_curriculum_datasets(
        train_dataset,
        config.curriculum_stages,
        use_simple=use_simple_curriculum,
        skip_stages=skip_curriculum_stages,
    )
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon,
        weight_decay=config.weight_decay
    )
    accum_steps = config.gradient_accumulation_steps

    # -----------------------------------------------------------------------
    # Calculate total optimizer steps: one per full accumulation window plus
    # one per trailing partial window, matching the update rule below.
    # -----------------------------------------------------------------------
    total_steps = 0
    for ds in curriculum_datasets:
        num_batches = math.ceil(len(ds) / config.batch_size)
        steps_per_epoch = max(1, math.ceil(num_batches / accum_steps))
        total_steps += steps_per_epoch * config.num_epochs
    if total_steps == 0:
        total_steps = 1  # safety guard (e.g. every stage skipped)

    # -----------------------------------------------------------------------
    # LR Scheduler
    # -----------------------------------------------------------------------
    if is_continue:
        scheduler = get_continue_schedule(optimizer, num_training_steps=total_steps)
        plateau_guard = PlateauLRGuard(
            scheduler,
            patience=config.plateau_patience,
            factor=config.plateau_factor,
            min_delta=config.plateau_min_delta,
        )
        print(f"[Scheduler] Continue-train: flat cosine decay from {config.learning_rate:.2e}")
    else:
        if config.lr_scheduler_type == "cosine":
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=config.warmup_steps,
                num_training_steps=total_steps
            )
        else:
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=config.warmup_steps,
                num_training_steps=total_steps
            )
        plateau_guard = None

    # -----------------------------------------------------------------------
    # Mixed precision scaler
    # -----------------------------------------------------------------------
    scaler = torch.cuda.amp.GradScaler() if config.use_fp16 and torch.cuda.is_available() else None

    # -----------------------------------------------------------------------
    # Training loop
    # -----------------------------------------------------------------------
    global_step = 0
    perplexity_history = []
    for stage_idx, stage_dataset in enumerate(curriculum_datasets):
        actual_stage = stage_idx + skip_curriculum_stages
        print(f"\n{'='*80}")
        print(f"CURRICULUM STAGE {actual_stage + 1}/{len(curriculum_datasets) + skip_curriculum_stages} "
              f"(running {stage_idx + 1} of {len(curriculum_datasets)})")
        print(f"Max sequence length: {stage_dataset.max_length}")
        print(f"Samples: {len(stage_dataset)}")
        if hasattr(stage_dataset, 'cot_ratio'):
            print(f"CoT ratio: {stage_dataset.cot_ratio:.0%}")
        print(f"{'='*80}\n")
        dataloader = DataLoader(
            stage_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
            num_workers=0,
            pin_memory=True if torch.cuda.is_available() else False
        )
        num_batches = len(dataloader)
        # A trailing partial accumulation window (0 when batches divide evenly).
        tail_size = num_batches % accum_steps
        tail_start = num_batches - tail_size

        for epoch in range(config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
            epoch_loss = 0.0
            optimizer.zero_grad()
            for step, batch in enumerate(dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                # Average over the real size of this micro-batch's window so a
                # trailing partial window yields an update of the usual scale.
                window = tail_size if tail_size and step >= tail_start else accum_steps
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                        task_loss = outputs['loss']
                        if ewc is not None:
                            task_loss = task_loss + config.ewc_lambda * ewc.penalty(model)
                        loss = task_loss / window
                    # Backward outside autocast, as recommended for AMP.
                    scaler.scale(loss).backward()
                else:
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    task_loss = outputs['loss']
                    if ewc is not None:
                        task_loss = task_loss + config.ewc_lambda * ewc.penalty(model)
                    loss = task_loss / window
                    loss.backward()
                epoch_loss += loss.item() * window

                # Step on every full window AND on the last batch of the epoch, so
                # leftover micro-batches are applied rather than discarded by the
                # next epoch's zero_grad().
                if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
                    if scaler is not None:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                    if scaler is not None:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    if global_step % config.logging_steps == 0:
                        avg_loss = epoch_loss / (step + 1)
                        current_lr = scheduler.get_last_lr()[0]
                        print(f"Step {global_step:>6} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")

            avg_epoch_loss = epoch_loss / max(1, len(dataloader))
            perplexity = math.exp(min(avg_epoch_loss, 20))  # cap to avoid overflow
            perplexity_history.append(perplexity)
            print(f"Epoch {epoch + 1} completed | Avg Loss: {avg_epoch_loss:.4f} "
                  f"| Perplexity: {perplexity:.2f}")
            # Plateau check for continue-train
            if plateau_guard is not None:
                plateau_guard.step(perplexity)

    print("\n" + "="*80)
    print("TRAINING COMPLETED")
    if perplexity_history:
        print(f"Final perplexity: {perplexity_history[-1]:.2f}")
        print(f"Best perplexity: {min(perplexity_history):.2f}")
    print("="*80 + "\n")
    # Only the model is returned; optimizer/scaler state is not exposed to callers.
    return model
# ============================================================================
# EVALUATION
# ============================================================================
def evaluate_model(model, dataset, device, batch_size=4):
    """Compute average next-token loss and perplexity of ``model`` on ``dataset``.

    The per-batch losses (each already a mean over non-padding target tokens)
    are averaged with weights equal to the batch size. The model is left in
    eval mode.

    Returns:
        ``{"loss": float, "perplexity": float}``. Both are NaN when ``dataset``
        is empty; previously an empty set reported loss 0 and an "Excellent"
        perplexity of 1.
    """
    model.eval()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
        num_workers=0
    )
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += outputs['loss'].item() * input_ids.size(0)
            total_samples += input_ids.size(0)

    print("\nEvaluation Results:")
    if total_samples == 0:
        print("Warning: evaluation dataset is empty; no metrics computed")
        return {"loss": float('nan'), "perplexity": float('nan')}

    avg_loss = total_loss / total_samples
    perplexity = math.exp(min(avg_loss, 20))  # cap to avoid overflow
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
    # Thresholds calibrated for small models (15-30M params) on Indonesian CoT.
    # Large model benchmarks (GPT-2 etc) don't apply here — a tiny model doing
    # multi-step reasoning in Indonesian will naturally sit at higher perplexity.
    if perplexity < 5.0:
        print("Status: Excellent")
    elif perplexity < 10.0:
        print("Status: Good")
    elif perplexity < 20.0:
        print("Status: Fair — try more epochs or lower LR")
    else:
        print("Status: Poor — check data quality or model config")
    if avg_loss < 0.5:
        print("Warning: Very low loss might indicate overfitting")
    return {"loss": avg_loss, "perplexity": perplexity}
# ============================================================================
# GENERATION
# ============================================================================
def generate_text(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9,
    device: torch.device = torch.device('cpu')
):
    """
    Sample a continuation of `prompt` token by token and return the decoded text.

    Sampling pipeline per step:
      - temperature scaling (temperature is clamped to >= 0.1);
      - a fixed -2.0 logit penalty on tokens seen in the last 10 generated
        tokens (once more than 10 have been generated), stop tokens excluded;
      - top-k filtering (skipped when top_k <= 0);
      - nucleus (top-p) filtering (skipped when top_p >= 1.0).

    Generation stops at max_new_tokens, on a stop token (EOS / SEP / PAD, which is
    not appended), or after step 10 once the new text contains a blank line, or
    "User:", or "Assistant:" in its last 20 characters.

    Only the trailing `model.config.max_position_embeddings` tokens are fed to the
    model, so RoPE positions stay in range. The full sequence is kept, so the
    returned text always starts with the (encoded) prompt.

    Returns:
        The decoded prompt + continuation with special tokens kept
        (skip_special_tokens=False), so callers can split on <cot>/</cot>.
    """
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    prompt_len = input_ids.size(1)
    generated_ids = input_ids.clone()
    max_ctx = model.config.max_position_embeddings

    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = tokenizer.sep_token_id
    if eos_token_id is None:
        eos_token_id = 2
    stop_tokens = {eos_token_id, tokenizer.pad_token_id}
    if tokenizer.sep_token_id is not None:
        stop_tokens.add(tokenizer.sep_token_id)
    # A tokenizer without a pad token would otherwise put None in the set.
    stop_tokens.discard(None)

    repetition_buffer = []
    with torch.no_grad():
        for step in range(max_new_tokens):
            # Sliding window on the *model input* only. Trimming generated_ids
            # itself would drop the prompt from the output and invalidate the
            # prompt_len offset used below.
            context = generated_ids[:, -max_ctx:]
            outputs = model(input_ids=context)
            logits = outputs['logits']
            next_token_logits = logits[:, -1, :] / max(temperature, 0.1)

            if len(repetition_buffer) > 10:
                for token in set(repetition_buffer[-10:]):
                    if token in stop_tokens:
                        continue
                    next_token_logits[0, token] -= 2.0

            if top_k > 0:
                kth_best = torch.topk(
                    next_token_logits, min(top_k, next_token_logits.size(-1))
                )[0][..., -1, None]
                next_token_logits[next_token_logits < kth_best] = float('-inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept,
                # and always keep the single most likely token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            token_id = next_token.item()
            if token_id in stop_tokens:
                break

            repetition_buffer.append(token_id)
            if len(repetition_buffer) > 20:
                repetition_buffer.pop(0)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            if step > 10:
                decoded = tokenizer.decode(
                    generated_ids[0][prompt_len:], skip_special_tokens=False)
                if '\n\n' in decoded or 'User:' in decoded or 'Assistant:' in decoded[-20:]:
                    break

    # Keep <cot></cot> tags in the decoded text so the caller can split on them.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=False)
# ============================================================================
# INTERACTIVE CHAT
# ============================================================================
def _clean_response(response: str) -> str:
"""
Strip CoT block and training-format artifacts, return only the final answer.
Assumes generate_text() decoded with skip_special_tokens=False so tags visible.
"""
import re
# If full CoT block present, keep only the part after </cot>
if "<cot>" in response and "</cot>" in response:
response = response.split("</cot>", 1)[-1]
# If CoT opened but never closed, discard from <cot> onward
elif "<cot>" in response:
response = response.split("<cot>", 1)[0]
# Strip any remaining special/xml tokens
response = re.sub(r'\[\w+\]', '', response)
response = re.sub(r'<[^>]+>', '', response)
# Cut at hard training-format markers β€” these always signal garbage
for marker in [
"user :", "user:", "User :", "User:",
"assistant :", "assistant:", "Assistant :", "Assistant:",
"memahami permintaan", "jawaban singkat", "penjelasan harus",
"\n\n",
]:
if marker in response:
response = response.split(marker)[0]
# Strip leading punctuation/whitespace artifacts
response = re.sub(r'^[\s:!,\.\-|\[\]]+', '', response)
response = re.sub(r' {2,}', ' ', response).strip()
return response
def _extract_thinking(raw: str) -> tuple:
    """
    Split generated text (the part after the prompt's <cot>) into reasoning and answer.

    The expected shape is "...reasoning... </cot> ...answer...". Bracketed special
    tokens such as [CLS]/[SEP] are removed first. If there is no </cot>, the whole
    text counts as reasoning and the answer is the empty string.

    The reasoning is cut at a few hard markers and stripped of <...> tags, but
    otherwise kept as prose. The answer is cleaned by _clean_response().

    Returns:
        (thinking, answer), both str.
    """
    import re

    text = re.sub(r'\[\w+\]', '', raw)
    # partition() yields (text, "", "") when the tag is absent, which gives the
    # "all reasoning, empty answer" case for free.
    head, _, tail = text.partition("</cot>")

    thinking = head.strip()
    for cut in ("user :", "user:", "memahami permintaan", "\n\n"):
        if cut in thinking:
            thinking = thinking.split(cut)[0]
    thinking = re.sub(r'<[^>]+>', '', thinking).strip()

    return thinking, _clean_response(tail)
def interactive_chat(model, tokenizer, device,
                     system_prompt: str = "Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia."):
    """
    Run a stateless chat REPL on stdin/stdout.

    Each turn is sent as the bare training format "<user input> <cot>". The
    system prompt is only printed as the session persona. It is NOT prepended to
    turns, because the model never saw that format during training.

    Commands:
        exit / quit / keluar   end the session
        clear / bersihkan      prints a notice (no history is kept between turns)
        think                  toggle display of the model's reasoning text

    The loop also ends on Ctrl-C or end of input (EOF on stdin).
    """
    import time

    print("\n" + "="*80)
    print("INDONESIAN LLM β€” INTERACTIVE CHAT")
    print("="*80)
    print("Commands: 'exit'/'quit' | 'clear' | 'think' (toggle reasoning display)")
    print(f"Persona : {system_prompt}")
    print("="*80 + "\n")
    model.eval()
    show_thinking = False

    # Re-seed per session so responses vary across runs; compute the seed once
    # so CPU and CUDA generators get the same value.
    session_seed = int(time.time()) % 100000
    torch.manual_seed(session_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(session_seed)

    while True:
        try:
            user_input = input("\nYou: ").strip()
            if not user_input:
                continue
            if user_input.lower() in ['exit', 'quit', 'keluar']:
                print("\nGoodbye!")
                break
            if user_input.lower() in ['clear', 'bersihkan']:
                print("\nConversation cleared")
                continue
            if user_input.lower() == 'think':
                show_thinking = not show_thinking
                print(f"\nThinking mode: {'ON' if show_thinking else 'OFF'}")
                continue

            # Always prompt in CoT format: answers are better with it, and
            # without <cot> the model reasons inline with no closing tag to
            # strip on. Think mode only controls whether the reasoning is SHOWN.
            prompt = f"{user_input} <cot>"
            max_tokens = 250
            print("\nA:", end=" ", flush=True)
            full_response = generate_text(
                model=model, tokenizer=tokenizer, prompt=prompt,
                max_new_tokens=max_tokens, temperature=0.9,
                top_k=50, top_p=0.92, device=device
            )

            # The decoded text is not a verbatim copy of `prompt`: it may gain a
            # leading [CLS], be lower-cased, or collapse words to [UNK]. Slicing
            # by len(prompt) can therefore leave prompt fragments or eat
            # generated text. Anchor on the prompt's own <cot> tag(s) instead.
            n_tags = prompt.count("<cot>")
            pieces = full_response.split("<cot>", n_tags)
            if len(pieces) > n_tags:
                response = pieces[-1].strip()
            else:
                response = full_response[len(prompt):].strip()

            thinking, answer = _extract_thinking(response)
            if show_thinking and thinking:
                print(f"[Thinking: {thinking}]")
            final = answer if answer else _clean_response(response)
            if not final or len(final) < 3:
                final = "Maaf, saya tidak mengerti. Bisa diulang?"
            print(final)
        except KeyboardInterrupt:
            print("\n\nChat interrupted")
            break
        except EOFError:
            # stdin closed (piped input, notebook without a console). input()
            # would raise on every call, so continuing would loop forever.
            print("\n\nInput closed - ending chat")
            break
        except Exception as e:
            print(f"\nError: {e}")
# ============================================================================
# BENCHMARK
# ============================================================================
def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 20, verbose: bool = True):
    """
    Benchmark the model on `n` random samples drawn from a JSONL dataset.

    Each line must be a JSON object with non-empty string "input" and "output"
    fields; other lines are skipped. A sample passes when the generated answer
    (the text after </cot>):
      1. contains the expected output as a case-insensitive substring, or
      2. contains at least half of the expected output's whitespace-separated words.

    Perplexity is also computed on "<input> <cot> <output>" for each sample. The
    sequence is truncated to the model's context length.

    Sample selection uses a private RNG, so it differs on every run without
    touching the global `random` state. Generation is seeded with torch seed 42.

    Returns:
        {"score": percent passed, "pass": int, "total": int, "perplexity": float},
        or None if the dataset is missing or has no valid samples.
    """
    # -- Load samples --------------------------------------------------------
    if dataset_path is None or not os.path.exists(dataset_path):
        print(f"Dataset not found at: {dataset_path}")
        print("Pass --dataset path/to/your.jsonl to benchmark against your data.")
        return
    all_samples = []
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                d = json.loads(line)
                if all(k in d for k in ['input', 'output']) and d['input'].strip() and d['output'].strip():
                    all_samples.append(d)
            except (ValueError, TypeError, AttributeError):
                # Malformed JSON, non-object lines, or non-string fields: skip.
                continue
    if not all_samples:
        print("No valid samples found in dataset.")
        return

    # Different random samples every run, without reseeding the global RNG
    # that set_seed() may have configured.
    rng = random.Random()
    samples = rng.sample(all_samples, min(n, len(all_samples)))

    model.eval()
    # Fixed seed only for generation, so individual answers are reproducible
    # while the sample selection above varies.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    max_ctx = model.config.max_position_embeddings

    print("\n" + "="*80)
    print(f"BENCHMARK ({len(samples)} random samples from dataset)")
    print("="*80)

    results = []
    ppl_loss = 0.0
    ppl_toks = 0

    for sample in samples:
        inp = sample['input'].strip()
        expected = sample['output'].strip().lower()
        prompt = f"{inp} <cot>"

        full = generate_text(
            model=model, tokenizer=tokenizer, prompt=prompt,
            max_new_tokens=150, temperature=0.3,
            top_k=20, top_p=0.9, device=device
        )
        # The decoded text is not a verbatim copy of the prompt ([CLS] prefix,
        # lower-casing, [UNK]), so anchor on the prompt's <cot> tag(s) rather
        # than slicing by len(prompt).
        n_tags = prompt.count("<cot>")
        pieces = full.split("<cot>", n_tags)
        raw = pieces[-1].strip() if len(pieces) > n_tags else full[len(prompt):].strip()
        _, answer = _extract_thinking(raw)
        answer_lower = answer.lower()

        # Score 1: substring match
        passed = expected in answer_lower
        # Score 2: token overlap (if substring fails)
        if not passed:
            exp_tokens = set(expected.split())
            ans_tokens = set(answer_lower.split())
            if exp_tokens:
                overlap = len(exp_tokens & ans_tokens) / len(exp_tokens)
                passed = overlap >= 0.5
        results.append(passed)

        # Perplexity on input+expected
        with torch.no_grad():
            ids = tokenizer.encode(
                f"{inp} <cot> {sample['output']}",
                return_tensors="pt"
            ).to(device)
            # Keep RoPE positions in range for unusually long samples.
            ids = ids[:, :max_ctx]
            if ids.size(1) >= 2:
                out = model(input_ids=ids, labels=ids)
                toks = ids.size(1) - 1
                ppl_loss += out["loss"].item() * toks
                ppl_toks += toks

        if verbose:
            status = "PASS" if passed else "FAIL"
            print(f"  [{status}] {inp[:60]}")
            print(f"         Expected : {sample['output'][:80]}")
            print(f"         Got      : {answer[:80] if answer else '(no answer)'}")

    total_pass = sum(results)
    total = len(results)
    overall = total_pass / total * 100
    bar = "β–ˆ" * int(overall / 10) + "β–‘" * (10 - int(overall / 10))
    ppl = math.exp(min(ppl_loss / max(1, ppl_toks), 20))
    print("\n" + "-"*80)
    print(f"  SCORE       {total_pass}/{total}  ({overall:.1f}%)  {bar}")
    print(f"  PERPLEXITY  {ppl:.2f}  (lower = better)")
    print("="*80 + "\n")
    return {"score": overall, "pass": total_pass, "total": total, "perplexity": ppl}
# ============================================================================
# MODEL SAVING AND LOADING (model weights + config only; no optimizer/scaler state)
# ============================================================================
def save_model(model: IndonesianLLM, config: ModelConfig, tokenizer_name: str, save_path: str, use_fp16: bool = True):
    """
    Save model weights + config to a single .pt file. No optimizer or scaler state is saved.

    The file holds a dict with keys 'model_state_dict', 'config', 'tokenizer_name',
    'model_params' and 'dtype' ('fp16' or 'fp32').

    Args:
        model: model whose state_dict() is saved.
        config: ModelConfig stored alongside the weights (pickled).
        tokenizer_name: name that load_model() passes to AutoTokenizer.from_pretrained.
        save_path: destination file; its parent directory is created if needed.
        use_fp16: convert float32 tensors to float16 (roughly halves file size);
            tensors of other dtypes are stored unchanged.
    """
    parent_dir = os.path.dirname(save_path)
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)
    state_dict = model.state_dict()
    if use_fp16:
        # Tensors that alias the same memory (e.g. tied input/output embeddings
        # appear under two keys) are converted once, and the fp16 tensor is
        # reused. torch.save then writes that storage once instead of once per
        # key, which per-key .half() calls would do.
        converted = {}
        half_state = {}
        for key, tensor in state_dict.items():
            if tensor.dtype != torch.float32:
                half_state[key] = tensor
                continue
            alias = (str(tensor.device), tensor.data_ptr(),
                     tuple(tensor.shape), tuple(tensor.stride()))
            if alias not in converted:
                converted[alias] = tensor.half()
            half_state[key] = converted[alias]
        state_dict = half_state
    torch.save({
        'model_state_dict': state_dict,
        'config': config,
        'tokenizer_name': tokenizer_name,
        'model_params': model.count_parameters(),
        'dtype': 'fp16' if use_fp16 else 'fp32',
    }, save_path)
    size_mb = os.path.getsize(save_path) / 1e6
    print(f"\nModel saved to: {save_path} ({'fp16' if use_fp16 else 'fp32'}, {size_mb:.1f} MB)")
    print(f"Parameters: {model.count_parameters():,}")
def load_model(load_path: str, device: torch.device):
    """
    Load a checkpoint written by save_model() and rebuild model + tokenizer.

    The .pt file provides the model weights (fp16 or fp32, per its 'dtype' key),
    the pickled ModelConfig, the tokenizer name and the parameter count. fp16
    weights are cast back to float32 before loading. No optimizer or scaler
    state is read here.

    The tokenizer is re-created with AutoTokenizer.from_pretrained(tokenizer_name),
    and the <cot>/</cot> special tokens are re-added so its vocabulary matches
    training.

    Args:
        load_path: path to the .pt checkpoint.
        device: device the checkpoint tensors and the model are mapped to.

    Returns:
        (model, tokenizer, config, extra), where extra is
        {'training_metadata': checkpoint.get('training_metadata')}, i.e. None
        when the checkpoint has no such key.

    Raises:
        FileNotFoundError: if load_path does not exist.
    """
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"Model checkpoint not found: {load_path}")
    print(f"Loading model from: {load_path}")
    # weights_only=False is required because the checkpoint pickles a
    # ModelConfig object. Unpickling can execute arbitrary code, so only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(load_path, map_location=device, weights_only=False)
    config = checkpoint['config']
    tokenizer_name = checkpoint['tokenizer_name']
    dtype = checkpoint.get('dtype', 'fp32')
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
    model = IndonesianLLM(config)
    state_dict = checkpoint['model_state_dict']
    if dtype == 'fp16':
        state_dict = {k: v.float() if v.dtype == torch.float16 else v
                      for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.to(device)
    extra = {'training_metadata': checkpoint.get('training_metadata', None)}
    size_mb = os.path.getsize(load_path) / 1e6
    print(f"Model loaded ({dtype}, {size_mb:.1f} MB) | "
          f"Parameters: {checkpoint.get('model_params', model.count_parameters()):,}")
    return model, tokenizer, config, extra
# ============================================================================
# MAIN
# ============================================================================
def main():
    """
    Command-line entry point: parse arguments and run the selected mode(s).

    --inspect-data, --chat and --benchmark each run and return immediately.
    --train, --finetune and --continue-train run in that order if several are
    given. With no mode flag, the help text and an error line are printed.
    """
    # Tokenizer used to build a fresh model; also the fallback name saved with
    # a checkpoint when a loaded tokenizer does not report its origin.
    default_tokenizer_name = "indolem/indobert-base-uncased"

    parser = argparse.ArgumentParser(
        description="Indonesian Conversational LLM β€” Train, Chat, Finetune, or Continue")
    parser.add_argument('--train', action='store_true', help='Train model from scratch')
    parser.add_argument('--chat', action='store_true', help='Interactive chat mode')
    parser.add_argument('--finetune', action='store_true', help='Fine-tune on NEW data (lr/10)')
    parser.add_argument('--continue-train', action='store_true',
                        help='Continue training on SAME data with proper LR re-warmup')
    parser.add_argument('--inspect-data', action='store_true', help='Inspect dataset quality')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmark suite on a saved model')
    parser.add_argument('--no-eval', action='store_true', help='Skip evaluation after train/finetune/continue-train')
    parser.add_argument('--grad-accum', type=int, default=None, help='Override gradient accumulation steps (default: 32 train, 32 finetune, 32 continue)')
    parser.add_argument('--ewc-lambda', type=float, default=5000.0,
                        help='EWC penalty strength for --finetune (default 5000). Higher = less forgetting but slower to learn new data. 0 = disabled.')
    parser.add_argument('--ewc-samples', type=int, default=2000,
                        help='Samples used to estimate Fisher Information for EWC (default 2000).')
    parser.add_argument('--ewc-dataset', type=str, default=None,
                        help='Dataset the model was previously trained on; used to estimate the EWC Fisher '
                             'Information for --finetune. Defaults to --dataset (the new data), which '
                             'protects little of the old knowledge.')
    parser.add_argument('--no-ewc', action='store_true',
                        help='Disable EWC during finetuning (allow forgetting old data).')
    parser.add_argument('--dataset', type=str, default='indonesian_cot_dataset.jsonl')
    parser.add_argument('--model', type=str, default='indonesian_llm_model.pt')
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--max-length', type=int, default=512,
                        help='Max sequence length. Your dataset max is 448 so 512 is optimal.')
    parser.add_argument('--hidden-size', type=int, default=320)
    parser.add_argument('--num-layers', type=int, default=16)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--num-kv-heads', type=int, default=2)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--save-fp16', action='store_true', default=True)
    parser.add_argument('--save-fp32', action='store_true')
    parser.add_argument('--no-cot', action='store_true')
    parser.add_argument('--use-cot', action='store_true', default=True)
    parser.add_argument('--simple-curriculum', action='store_true')
    parser.add_argument('--cot-ratio', type=float, default=1.0)
    parser.add_argument(
        '--system-prompt', type=str,
        default='Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia.',
        help='System prompt prepended to every chat turn.'
    )
    # ------------------------------------------------------------------
    # continue-train quality controls
    # ------------------------------------------------------------------
    parser.add_argument(
        '--rewarm-steps', type=int, default=150,
        help='Steps to re-warm LR from rewarm-start-frac β†’ full LR (continue-train only). '
             'Default 150. Increase for larger models or when previous run ended far into decay. '
             '(Currently parsed but not passed to training.)')
    parser.add_argument(
        '--skip-stages', type=int, default=2,
        help='Number of early curriculum stages to skip in continue-train (default 2). '
             'Skip=2 means start directly at Stage 3 (full CoT, full length). '
             'Use 0 to re-run all stages (not recommended).')
    parser.add_argument(
        '--plateau-patience', type=int, default=2,
        help='Epochs without improvement before LR is halved in continue-train (default 2).')
    parser.add_argument(
        '--no-restore-optimizer', action='store_true',
        help='Do NOT restore optimizer state from checkpoint (useful for debugging). '
             '(Currently parsed but not passed to training.)')
    args = parser.parse_args()

    if not any([args.train, args.chat, args.finetune, args.continue_train, args.inspect_data, args.benchmark]):
        parser.print_help()
        print("\nError: Specify a mode: --train, --chat, --finetune, --continue-train, --inspect-data, or --benchmark")
        return

    save_fp16 = not args.save_fp32
    use_cot_training = not args.no_cot
    set_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nDevice: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n")

    # ------------------------------------------------------------------
    # INSPECT DATA
    # ------------------------------------------------------------------
    if args.inspect_data:
        print("\nInspecting dataset...")
        print("="*80)
        tokenizer = AutoTokenizer.from_pretrained(default_tokenizer_name)
        tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
        dataset = IndonesianCoTDataset(
            file_path=args.dataset, tokenizer=tokenizer,
            max_length=args.max_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio
        )
        print(f"\nDataset Statistics:")
        print(f"Total samples: {len(dataset)}")
        print(f"Skipped samples: {dataset.skipped_count}")
        lengths = []
        cot_lengths = []
        for i in range(min(len(dataset), 1000)):
            sample = dataset[i]
            lengths.append(sample['length'])
            cot_lengths.append(sample['cot_length'])
        print(f"\nSequence Length Stats:")
        print(f"  Min: {min(lengths)}")
        print(f"  Max: {max(lengths)}")
        print(f"  Avg: {sum(lengths)/len(lengths):.1f}")
        print(f"  Median: {sorted(lengths)[len(lengths)//2]}")
        print(f"\nCoT Length Stats:")
        print(f"  Min: {min(cot_lengths)}")
        print(f"  Max: {max(cot_lengths)}")
        print(f"  Avg: {sum(cot_lengths)/len(cot_lengths):.1f}")
        long_cot = sum(1 for x in cot_lengths if x > 50)
        print(f"  Samples with long CoT (>50 tokens): {long_cot} ({long_cot/len(cot_lengths)*100:.1f}%)")
        print(f"\n{'='*80}")
        print("Sample Examples (first 5):")
        print("="*80)
        for i in range(min(5, len(dataset.samples))):
            s = dataset.samples[i]
            print(f"\n--- Sample {i+1} ---")
            print(f"Input: {s['input'][:100]}...")
            print(f"CoT: {s['cot'][:150]}...")
            print(f"Output: {s['output'][:100]}...")
        print("\n" + "="*80)
        print("Dataset Quality Checks:")
        issues = []
        for i, sample in enumerate(dataset.samples[:100]):
            if len(sample['input']) < 10:
                issues.append(f"Sample {i}: Input too short")
            if len(sample['output']) < 10:
                issues.append(f"Sample {i}: Output too short")
            if len(sample['cot']) < 20:
                issues.append(f"Sample {i}: CoT too short")
            if sample['input'].lower() == sample['output'].lower():
                issues.append(f"Sample {i}: Input == Output (copy)")
        if issues:
            print(f"\nFound {len(issues)} potential issues in first 100 samples:")
            for issue in issues[:10]:
                print(f"  - {issue}")
            if len(issues) > 10:
                print(f"  ... and {len(issues)-10} more")
        else:
            print("\nNo obvious issues detected in first 100 samples")
        print("\n" + "="*80)
        return

    # ------------------------------------------------------------------
    # CHAT
    # ------------------------------------------------------------------
    if args.chat:
        print("\nStarting CHAT mode...")
        if not os.path.exists(args.model):
            print(f"Error: Model checkpoint not found: {args.model}")
            return
        model, tokenizer, config, _ = load_model(args.model, device)
        interactive_chat(model, tokenizer, device, system_prompt=args.system_prompt)
        return

    # ------------------------------------------------------------------
    # BENCHMARK
    # ------------------------------------------------------------------
    if args.benchmark:
        print("\nRunning benchmark...")
        if not os.path.exists(args.model):
            print(f"Error: Model not found: {args.model}")
            return
        model, tokenizer, config, _ = load_model(args.model, device)
        run_benchmark(model, tokenizer, device, dataset_path=args.dataset)
        return

    # ------------------------------------------------------------------
    # TRAIN FROM SCRATCH
    # ------------------------------------------------------------------
    if args.train:
        print("\nStarting TRAINING mode (from scratch)...")
        tokenizer = AutoTokenizer.from_pretrained(default_tokenizer_name)
        tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
        model_config = ModelConfig(
            vocab_size=len(tokenizer),
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            num_attention_heads=args.num_heads,
            num_key_value_heads=args.num_kv_heads,
            intermediate_size=args.hidden_size * 3,
            max_position_embeddings=2048,
            attention_dropout=0.1,
            residual_dropout=0.1,
            tie_word_embeddings=True
        )
        model = IndonesianLLM(model_config)
        print(f"Model parameters: {model.count_parameters():,}")
        _ga = args.grad_accum if args.grad_accum else 32
        train_config = TrainingConfig(
            dataset_path=args.dataset,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=_ga,
            max_seq_length=args.max_length,
            learning_rate=args.lr,
            warmup_steps=500,   # 297k samples: a longer warmup stabilises early training
            use_fp16=torch.cuda.is_available(),
            curriculum_stages=[128, 256, args.max_length]  # matches the data (avg 163, max 448)
        )
        dataset = IndonesianCoTDataset(
            file_path=train_config.dataset_path, tokenizer=tokenizer,
            max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio
        )
        model = train_model(
            model, dataset, train_config, device,
            use_simple_curriculum=args.simple_curriculum
        )
        if not args.no_eval:
            evaluate_model(model, dataset, device)
        save_model(model, model_config, default_tokenizer_name, args.model, use_fp16=save_fp16)
        test_prompts = [
            "Berapa hasil dari 1+1?",
            "Jelaskan cara kerja komputer",
            "Bagaimana cara membuat kopi yang enak?"
        ]
        print("\n" + "="*80)
        print("GENERATION TEST")
        print("="*80 + "\n")
        for prompt in test_prompts:
            print(f"Prompt: {prompt}")
            generated = generate_text(model, tokenizer, prompt, max_new_tokens=150, device=device)
            print(f"Generated: {generated}\n")
            print("-" * 80 + "\n")
        print(f"\nTo chat: python {__file__} --chat --model {args.model}")

    # ------------------------------------------------------------------
    # FINETUNE (new data, lr/10)
    # ------------------------------------------------------------------
    if args.finetune:
        print("\nStarting FINETUNE mode (for NEW data)...")
        if not os.path.exists(args.model):
            print(f"Error: Model checkpoint not found: {args.model}")
            return
        model, tokenizer, model_config, extra = load_model(args.model, device)
        # Save with the tokenizer the checkpoint actually used, not a hard-coded name.
        saved_tokenizer_name = getattr(tokenizer, "name_or_path", None) or default_tokenizer_name
        _ga = args.grad_accum if args.grad_accum else 32
        train_config = TrainingConfig(
            dataset_path=args.dataset,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=_ga,
            max_seq_length=args.max_length,
            learning_rate=args.lr / 10,
            warmup_steps=100,
            use_fp16=torch.cuda.is_available(),
            curriculum_stages=[128, 256, args.max_length]
        )
        dataset = IndonesianCoTDataset(
            file_path=train_config.dataset_path, tokenizer=tokenizer,
            max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio
        )
        print(f"\nStarting fine-tuning with LR={train_config.learning_rate:.2e}...")

        # -- EWC: estimate Fisher on the OLD data before training on new data --
        ewc_obj = None
        if not args.no_ewc and args.ewc_lambda > 0:
            print(f"\nComputing EWC Fisher Information (lambda={args.ewc_lambda}, samples={args.ewc_samples})...")
            print("  This takes ~1-2 min on T4. Prevents forgetting old training data.")
            ewc_source = args.ewc_dataset or args.dataset
            if args.ewc_dataset is None:
                print("  Warning: --ewc-dataset not given; estimating Fisher on --dataset (the NEW data). "
                      "Pass the original training data to actually protect old knowledge.")
            old_dataset = IndonesianCoTDataset(
                file_path=ewc_source, tokenizer=tokenizer,
                max_length=train_config.max_seq_length, use_cot=use_cot_training,
                cot_ratio=args.cot_ratio
            )
            old_loader = DataLoader(
                old_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
                num_workers=0
            )
            train_config.ewc_lambda = args.ewc_lambda
            train_config.ewc_samples = args.ewc_samples
            ewc_obj = EWC(model, old_loader, device, n_samples=args.ewc_samples)
            print(f"  Fisher computed. EWC penalty will be added during training.")
        else:
            print("  EWC disabled β€” model may forget previous training (use --ewc-lambda to enable).")

        model = train_model(
            model, dataset, train_config, device,
            use_simple_curriculum=args.simple_curriculum,
            ewc=ewc_obj,
        )
        if not args.no_eval:
            evaluate_model(model, dataset, device)
        # Insert the suffix before the real extension. str.replace('.pt', ...)
        # would return the unchanged path (overwriting the source checkpoint)
        # for names without '.pt', and would also rewrite '.pt' inside directories.
        root, ext = os.path.splitext(args.model)
        finetuned_path = f"{root}_finetuned{ext or '.pt'}"
        save_model(model, model_config, saved_tokenizer_name, finetuned_path, use_fp16=save_fp16)
        print(f"\nFine-tuning completed. Model saved to: {finetuned_path}")
        print(f"To chat: python {__file__} --chat --model {finetuned_path}")

    # ------------------------------------------------------------------
    # CONTINUE TRAINING
    # ------------------------------------------------------------------
    if args.continue_train:
        print("\nStarting CONTINUE-TRAIN mode (improved)...")
        print("="*80)
        print("NOTE: For a 15-30M param model on Indonesian CoT, perplexity 5-15 is normal.")
        print("      The old thresholds were calibrated for 1B+ models β€” ignore 'Poor' labels.")
        print("Key improvements over the original continue-train:")
        print("  1. Optimizer state restored   β†’ smooth Adam updates from step 1")
        print("  2. Re-warmup cosine schedule  β†’ no destructive LR spike at restart")
        print("  3. Early stages skipped       β†’ no wasted time on short sequences")
        print("  4. Plateau LR reduction       β†’ auto-halves LR if perplexity stalls")
        print("="*80)
        if not os.path.exists(args.model):
            print(f"Error: Model checkpoint not found: {args.model}")
            return
        model, tokenizer, model_config, extra = load_model(args.model, device)
        saved_tokenizer_name = getattr(tokenizer, "name_or_path", None) or default_tokenizer_name

        # Cold-Adam reality: without saved optimizer state, AdamW's bias
        # correction makes the effective first step = full LR regardless of
        # any warmup schedule (v starts at 0). The only safe move is a genuinely
        # micro LR so even that cold spike is harmless.
        # The effective LR is always 5% of --lr (e.g. --lr 2e-4 gives 1e-5).
        effective_lr = args.lr * 0.05
        print(f"\nContinue-train LR: {args.lr:.2e} Γ— 0.05 = {effective_lr:.2e}")
        print(f"  Adam cold-starts β€” micro LR prevents overshooting the trained minimum.")
        print("  Adjust via --lr: effective LR = --lr x 0.05 (e.g. --lr 2e-4 gives 1e-5).")

        # Dataset-tuned curriculum: avg=163, max=448 -> stages match the real distribution.
        # Continue always skips short stages; the model already learned those patterns.
        curriculum_stages = [192, 320, args.max_length]
        if args.simple_curriculum:
            effective_skip = len(curriculum_stages) - 1   # go straight to full-length
        else:
            effective_skip = args.skip_stages
        print(f"Skipping first {effective_skip} curriculum stage(s) β€” training on full-length data only.")
        _ga = args.grad_accum if args.grad_accum else 32
        # No LR warmup is configured here (warmup_steps=0): the micro LR above
        # is what keeps the cold-Adam first updates small.
        train_config = TrainingConfig(
            dataset_path=args.dataset,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=_ga,
            max_seq_length=args.max_length,
            learning_rate=effective_lr,
            warmup_steps=0,
            use_fp16=torch.cuda.is_available(),
            curriculum_stages=curriculum_stages,
            skip_curriculum_stages=effective_skip,
            plateau_patience=args.plateau_patience,   # default 2: 297k samples give a strong per-epoch signal
            plateau_factor=0.5,
            plateau_min_delta=0.02,
        )
        dataset = IndonesianCoTDataset(
            file_path=train_config.dataset_path, tokenizer=tokenizer,
            max_length=train_config.max_seq_length, use_cot=use_cot_training,
            cot_ratio=args.cot_ratio
        )
        model = train_model(
            model, dataset, train_config, device,
            use_simple_curriculum=args.simple_curriculum,
            is_continue=True,
            skip_curriculum_stages=effective_skip,
        )
        if not args.no_eval:
            print("\nEvaluating continued model...")
            evaluate_model(model, dataset, device)
        save_model(model, model_config, saved_tokenizer_name, args.model, use_fp16=save_fp16)
        print(f"\nContinued training completed.")
        print(f"Model saved to: {args.model}")
        print(f"To chat: python {__file__} --chat --model {args.model}")
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0, just like falling off the end of the module.
    raise SystemExit(main())