| """
|
| Production-Grade Indonesian Conversational Language Model
|
| Trained from scratch with Chain-of-Thought reasoning capability
|
|
|
| Architecture: Decoder-only transformer with MQA/GQA, RoPE, SwiGLU, RMSNorm
|
| Target: 15M-30M parameters, optimized for Google Colab Free tier
|
| """
|
|
|
import argparse
import json
import math
import os
import random
import warnings
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
|
| warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Model architecture configuration"""
    vocab_size: int = 30000
    hidden_size: int = 384             # model width (embedding dimension)
    num_layers: int = 12               # number of decoder blocks
    num_attention_heads: int = 6       # query heads
    num_key_value_heads: int = 2       # shared KV heads (GQA); must divide num_attention_heads
    intermediate_size: int = 1024      # SwiGLU MLP inner width
    max_position_embeddings: int = 2048
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0        # RoPE base frequency
    attention_dropout: float = 0.1
    residual_dropout: float = 0.1
    initializer_range: float = 0.02    # std of the normal weight init
    use_cache: bool = False            # NOTE(review): not read anywhere in this file
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2
    tie_word_embeddings: bool = True   # reuse the embedding matrix as the LM head

    def __post_init__(self):
        # Head dim must be integral, and query heads must split evenly
        # into KV groups for grouped-query attention.
        assert self.hidden_size % self.num_attention_heads == 0
        assert self.num_attention_heads % self.num_key_value_heads == 0
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Training hyperparameters.

    Grouped by concern: data/paths, schedule, optimizer, LR schedule,
    regularization, precision, reproducibility, logging, curriculum,
    plateau guard, and EWC anti-forgetting.
    """
    # --- data / output ---
    dataset_path: str = "indonesian_cot_dataset.jsonl"
    output_dir: str = "./indonesian_llm_checkpoints"

    # --- schedule ---
    num_epochs: int = 3
    batch_size: int = 4
    gradient_accumulation_steps: int = 12   # effective batch = batch_size * this
    max_seq_length: int = 1024

    # --- optimizer (AdamW) ---
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0

    # --- LR schedule ---
    warmup_steps: int = 100
    lr_scheduler_type: str = "cosine"       # "cosine", anything else -> linear

    # --- regularization ---
    dropout: float = 0.1

    # --- precision ---
    use_fp16: bool = True

    # --- reproducibility ---
    seed: int = 42

    # --- logging / checkpoints ---
    logging_steps: int = 10
    eval_steps: int = 100
    save_steps: int = 500

    # Per-stage max sequence lengths for the simple curriculum.
    # FIX: was `List[int] = None` (mistyped, None-as-default). A
    # default_factory gives each instance its own list; passing None
    # explicitly is still accepted for backward compatibility (coerced
    # in __post_init__).
    curriculum_stages: Optional[List[int]] = field(
        default_factory=lambda: [256, 512, 1024]
    )

    # Continue-train: skip the first N short-sequence curriculum stages.
    skip_curriculum_stages: int = 2

    # Plateau guard: checks without improvement / LR factor / min relative gain.
    plateau_patience: int = 3
    plateau_factor: float = 0.5
    plateau_min_delta: float = 0.02

    # EWC anti-forgetting penalty weight (0.0 disables) and Fisher sample count.
    ewc_lambda: float = 0.0
    ewc_samples: int = 2000

    def __post_init__(self):
        # Preserve the old API where callers could pass curriculum_stages=None.
        if self.curriculum_stages is None:
            self.curriculum_stages = [256, 512, 1024]
|
|
|
|
|
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary positional embeddings (RoPE) with a lazily grown cos/sin cache."""

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # One inverse frequency per even channel pair.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """(Re)build the cos/sin lookup tables for `seq_len` positions."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate so the table covers the full head dim (rotate-half layout).
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        # Grow the cache on demand if a longer sequence shows up.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return cos, sin
|
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dim, negating the (moved) second half."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors by the cached cos/sin tables (RoPE)."""
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm: no mean subtraction, no bias, learned scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        # Statistics are computed in fp32 for numerical stability,
        # then the result is cast back to the caller's dtype.
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(original_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Causal self-attention with grouped (shared) KV heads and RoPE."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        assert self.hidden_size % self.num_heads == 0

        kv_dim = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(self, hidden_states, attention_mask=None):
        bsz, seq_len, _ = hidden_states.size()

        # Project, then reshape to (batch, heads, seq, head_dim).
        q = self.q_proj(hidden_states).view(
            bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(
            bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(
            bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(v, seq_len=seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Duplicate the shared KV heads so every query head has a partner.
        if self.num_key_value_groups > 1:
            k = k.repeat_interleave(self.num_key_value_groups, dim=1)
            v = v.repeat_interleave(self.num_key_value_groups, dim=1)

        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)

        # `attention_mask` is additive: 0 for allowed, dtype-min for blocked.
        if attention_mask is not None:
            scores = scores + attention_mask

        # Softmax in fp32 for stability, then cast back to the compute dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        probs = self.attention_dropout(probs)

        context = torch.matmul(probs, v)
        context = context.transpose(1, 2).contiguous().reshape(bsz, seq_len, self.hidden_size)
        return self.o_proj(context)
|
|
|
|
|
|
|
|
|
|
|
|
|
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        # SiLU-gated linear unit, projected back down to the model width.
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder block: GQA attention, then SwiGLU MLP."""

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SwiGLUMLP(config)
        self.residual_dropout = nn.Dropout(config.residual_dropout)

    def _residual(self, x, sublayer_out):
        # Dropout the sublayer output before adding the skip connection.
        return x + self.residual_dropout(sublayer_out)

    def forward(self, hidden_states, attention_mask=None):
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), attention_mask=attention_mask
        )
        hidden_states = self._residual(hidden_states, attn_out)

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return self._residual(hidden_states, mlp_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
class IndonesianLLM(nn.Module):
    """Decoder-only causal LM.

    Pipeline: token embedding -> num_layers pre-norm decoder blocks ->
    final RMSNorm -> LM head. When `tie_word_embeddings` is set, the
    output projection reuses the embedding matrix instead of a separate
    Linear (saves vocab_size * hidden_size parameters).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        self.layers = nn.ModuleList([DecoderLayer(config, idx) for idx in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Weight tying: lm_head=None signals forward() to reuse embed_tokens.weight.
        if config.tie_word_embeddings:
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Normal(0, initializer_range) for Linear/Embedding weights;
        # biases and the padding row are zeroed.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.embed_tokens

    def _prepare_attention_mask(self, attention_mask, input_shape, dtype):
        """Build an additive (batch, 1, seq, seq) mask.

        0.0 where attention is allowed; dtype-min where it is blocked
        (future positions and padded key positions).

        NOTE(review): `attention_mask.device` is read before the None
        check below; forward() always supplies a mask, but calling this
        directly with None would fail — confirm intended.
        """
        batch_size, seq_length = input_shape
        # True strictly above the diagonal = future positions to block.
        causal_mask = torch.triu(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=attention_mask.device),
            diagonal=1
        )
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, seq_length, seq_length)

        if attention_mask is not None:
            # Merge the padding mask (1 = real token, 0 = pad) into the causal mask.
            expanded_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length)
            expanded_mask = expanded_mask.bool()
            causal_mask = causal_mask | ~expanded_mask

        # Convert boolean "blocked" flags to additive bias values.
        causal_mask = torch.where(causal_mask, torch.finfo(dtype).min, 0.0)
        return causal_mask

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the decoder stack.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) 0/1 padding mask;
                defaults to all-ones (no padding).
            labels: optional (batch, seq) targets; when given, a shifted
                next-token cross-entropy loss is computed (pad id ignored).

        Returns:
            dict with 'loss' (None when labels is None) and
            'logits' of shape (batch, seq, vocab_size).
        """
        batch_size, seq_length = input_ids.shape

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        hidden_states = self.embed_tokens(input_ids)
        attention_mask = self._prepare_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states.dtype
        )

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm(hidden_states)

        # Tied head: project against the embedding matrix directly.
        if self.lm_head is not None:
            logits = self.lm_head(hidden_states)
        else:
            logits = F.linear(hidden_states, self.embed_tokens.weight)

        loss = None
        if labels is not None:
            # Standard causal shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        return {"loss": loss, "logits": logits}

    def count_parameters(self):
        # Trainable parameters only.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
|
|
|
|
|
|
|
class IndonesianCoTDataset(Dataset):
    """JSONL dataset of {'input', 'cot', 'output'} samples with optional
    Chain-of-Thought formatting applied at __getitem__ time.

    Formatting is stochastic: with probability `cot_ratio` a sample is
    rendered as "input <cot> cot </cot> output", otherwise "input output".
    Malformed JSONL rows are skipped (with a warning) rather than aborting.
    """

    def __init__(
        self,
        file_path: str,
        tokenizer,
        max_length: int = 1024,
        cot_token: str = "<cot>",
        end_cot_token: str = "</cot>",
        use_cot: bool = True,
        cot_ratio: float = 0.7
    ):
        # tokenizer: assumed HF-style (callable, with .encode) — confirm at call site.
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.cot_token = cot_token
        self.end_cot_token = end_cot_token
        self.use_cot = use_cot
        self.cot_ratio = cot_ratio
        self.samples = []
        self.skipped_count = 0

        self._load_data(file_path)

    def _load_data(self, file_path: str):
        """Read and validate the JSONL file; keep only well-formed samples.

        Raises ValueError if no valid samples were loaded.
        """
        print(f"Loading dataset from {file_path}...")

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                try:
                    if not line.strip():
                        continue  # silently skip blank lines
                    data = json.loads(line)

                    # Required keys must be present ...
                    if not all(key in data for key in ['input', 'cot', 'output']):
                        self.skipped_count += 1
                        print(f"Warning: Line {line_num} missing required fields, skipping...")
                        continue

                    # ... must be strings ...
                    if not all(isinstance(data[key], str) for key in ['input', 'cot', 'output']):
                        self.skipped_count += 1
                        print(f"Warning: Line {line_num} has invalid data types, skipping...")
                        continue

                    # ... and must be non-empty after stripping whitespace.
                    if not all(data[key].strip() for key in ['input', 'cot', 'output']):
                        self.skipped_count += 1
                        print(f"Warning: Line {line_num} has empty fields, skipping...")
                        continue

                    self.samples.append(data)

                except json.JSONDecodeError as e:
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} is not valid JSON ({e}), skipping...")
                    continue
                except Exception as e:
                    # Broad catch is deliberate: one bad row must never abort loading.
                    self.skipped_count += 1
                    print(f"Warning: Line {line_num} caused error ({e}), skipping...")
                    continue

        print(f"Loaded {len(self.samples)} valid samples")
        print(f"Skipped {self.skipped_count} malformed rows")

        if len(self.samples) == 0:
            raise ValueError("No valid samples loaded from dataset!")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Per-access coin flip: the same index can yield CoT-formatted text
        # on one epoch and plain input/output on another.
        if self.use_cot:
            if random.random() < self.cot_ratio:
                text = f"{sample['input']} {self.cot_token} {sample['cot']} {self.end_cot_token} {sample['output']}"
            else:
                text = f"{sample['input']} {sample['output']}"
        else:
            text = f"{sample['input']} {sample['output']}"

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding=False,  # padding is done later by the collate fn
            return_tensors=None
        )
        input_ids = encoding['input_ids']

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'length': len(input_ids),
            # NOTE(review): cot_length is reported whenever use_cot is set,
            # even on draws where the CoT text was omitted from `text`.
            'cot_length': len(self.tokenizer.encode(sample['cot'], add_special_tokens=False)) if self.use_cot else 0
        }
|
|
|
|
|
def collate_fn_with_packing(batch, pad_token_id=0):
    """Pad a list of variable-length samples into a rectangular batch.

    Sorts longest-first, right-pads `input_ids` with `pad_token_id`, and
    builds a 0/1 attention mask. Labels are a copy of the padded input ids
    (the model shifts internally and ignores the pad id in its loss).
    """
    ordered = sorted(batch, key=lambda item: item['length'], reverse=True)
    target_len = ordered[0]['length']  # longest sample sets the batch width

    padded_ids = []
    masks = []
    for item in ordered:
        n = item['length']
        pad = target_len - n
        padded_ids.append(F.pad(item['input_ids'], (0, pad), value=pad_token_id))
        masks.append(torch.cat([
            torch.ones(n, dtype=torch.long),
            torch.zeros(pad, dtype=torch.long)
        ]))

    return {
        'input_ids': torch.stack(padded_ids),
        'attention_mask': torch.stack(masks),
        'labels': torch.stack(padded_ids)
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_curriculum_datasets(dataset, stages=(256, 512, 1024), use_simple=False, skip_stages=0):
    """
    Build per-stage datasets.

    Args:
        dataset: the full IndonesianCoTDataset to slice into stages.
        stages: per-stage max token lengths for the simple curriculum.
            FIX: was a mutable list default argument; an immutable tuple
            avoids the shared-mutable-default pitfall (lists still accepted).
        use_simple: one stage per entry in `stages`, filtered by tokenized
            length of the fully CoT-formatted text.
        skip_stages: skip the first N short-sequence stages (continue-train only).

    When use_simple=True and skip_stages >= len(stages)-1, trains on full
    dataset at max length only - which is exactly what you want for continuation.
    """
    datasets = []

    if use_simple:
        for i, max_len in enumerate(stages):
            # Keep only samples whose full CoT-formatted text fits in max_len.
            filtered_samples = [
                s for s in dataset.samples
                if len(dataset.tokenizer.encode(
                    f"{s['input']} {dataset.cot_token} {s['cot']} {dataset.end_cot_token} {s['output']}"
                )) <= max_len
            ]
            stage_dataset = _build_stage_dataset(dataset, filtered_samples, max_len, dataset.cot_ratio)
            datasets.append(stage_dataset)
            skipped = i < skip_stages
            # FIX: the stage label was a garbled literal ("68,026");
            # report the actual 1-based stage number.
            print(f"{'[SKIP] ' if skipped else ''}Curriculum stage {i + 1}: {len(filtered_samples)} samples")
    else:
        print("\n" + "="*80)
        print("3-STAGE REASONING CURRICULUM")
        if skip_stages > 0:
            print(f" (Skipping first {skip_stages} stage(s) - continue-train mode)")
        print("="*80)

        # Fixed schedule: short no-CoT -> medium 50% CoT -> full 100% CoT.
        stage_configs = [
            {
                'name': 'Stage 1: Basic Q&A (short, no reasoning)',
                'max_len': 384,
                'cot_ratio': 0.0,
                'filter': lambda s: len(dataset.tokenizer.encode(
                    f"{s['input']} {s['output']}"
                )) <= 384
            },
            {
                'name': 'Stage 2: Learning Reasoning (medium, 50% CoT)',
                'max_len': 512,
                'cot_ratio': 0.5,
                'filter': lambda s: True
            },
            {
                'name': 'Stage 3: Full Reasoning (all, 100% CoT)',
                'max_len': 1024,
                'cot_ratio': 1.0,
                'filter': lambda s: True
            }
        ]

        for idx, stage_config in enumerate(stage_configs):
            filtered_samples = [s for s in dataset.samples if stage_config['filter'](s)]
            stage_dataset = _build_stage_dataset(
                dataset, filtered_samples,
                stage_config['max_len'], stage_config['cot_ratio']
            )
            datasets.append(stage_dataset)

            skipped = idx < skip_stages
            prefix = " [SKIP] " if skipped else " "
            print(f"{prefix}{stage_config['name']}")
            print(f" {'(skipped)' if skipped else ''} Samples: {len(filtered_samples)}")
            print(f" {'(skipped)' if skipped else ''} Max length: {stage_config['max_len']}")
            print(f" {'(skipped)' if skipped else ''} CoT ratio: {stage_config['cot_ratio']*100:.0f}%")

        print("="*80 + "\n")

    # Drop the stages the caller asked to skip (continue-train mode).
    if skip_stages > 0:
        datasets = datasets[skip_stages:]

    return datasets
|
|
|
|
|
def _build_stage_dataset(base_dataset, samples, max_len, cot_ratio):
    """Helper: create a shallow-copy stage dataset from a list of samples."""
    # __new__ skips __init__, so no file is re-read for the stage copy.
    clone = IndonesianCoTDataset.__new__(IndonesianCoTDataset)
    clone.tokenizer = base_dataset.tokenizer
    clone.max_length = max_len
    clone.cot_token = base_dataset.cot_token
    clone.end_cot_token = base_dataset.end_cot_token
    clone.samples = samples
    clone.skipped_count = 0
    clone.use_cot = base_dataset.use_cot
    clone.cot_ratio = cot_ratio
    return clone
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Linear warmup to the base LR, then linear decay down to zero."""
    def scale(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        remaining = num_training_steps - step
        return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, scale)
|
|
|
|
|
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Linear warmup to the base LR, then cosine decay down to zero."""
    span = max(1, num_training_steps - num_warmup_steps)  # decay phase length

    def scale(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / span
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, scale)
|
|
|
|
|
def get_continue_schedule(optimizer, num_training_steps: int, min_fraction: float = 0.1):
    """
    Schedule for --continue-train without saved optimizer state.

    No re-warmup: AdamW's bias correction makes the effective first update
    proportional to the full LR anyway (second-moment state starts at zero),
    so a warmup phase would only delay real training. Instead, start at the
    target LR immediately and decay along a cosine to min_fraction * LR by
    the end — the gentlest profile for a pretrained model with a cold
    optimizer.
    """
    def scale(step):
        frac_done = min(step / max(1, num_training_steps), 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * frac_done))
        # Maps 1.0 at step 0 down to min_fraction at the final step.
        return min_fraction + (1.0 - min_fraction) * cosine

    return torch.optim.lr_scheduler.LambdaLR(optimizer, scale)
|
|
|
|
|
class PlateauLRGuard:
    """
    Wraps any LambdaLR scheduler and applies an extra multiplicative penalty
    when perplexity has not improved for `patience` consecutive checks.

    FIXES over the original:
    1. The first measurement now records the baseline. Previously `_best`
       started at inf, so the first relative-improvement computation was
       inf/inf = nan, `_best` was never updated, and *every* call counted
       as "no improvement" — the penalty fired every `patience` calls even
       while training was improving.
    2. The penalty now also scales the scheduler's `base_lrs`. LambdaLR
       recomputes ``lr = base_lr * lambda(step)`` on every
       ``scheduler.step()``, so scaling only ``param_group['lr']`` was
       silently undone at the next step.

    Usage:
        guard = PlateauLRGuard(scheduler, patience=3, factor=0.5, min_delta=0.02)
        guard.step(current_perplexity)  # call after every eval period
    """

    def __init__(self, scheduler, patience=3, factor=0.5, min_delta=0.02):
        self.scheduler = scheduler
        self.patience = patience      # stale checks tolerated before penalising
        self.factor = factor          # multiplicative LR reduction
        self.min_delta = min_delta    # minimum relative improvement that counts
        self._best = float('inf')     # best perplexity seen so far (inf = unset)
        self._no_improve = 0
        self._penalty = 1.0           # cumulative penalty applied so far

    def step(self, perplexity: float):
        """Call after every evaluation. Returns True if a penalty was applied."""
        relative_improvement = (self._best - perplexity) / max(self._best, 1e-8)

        # First measurement (baseline) or a sufficient relative improvement.
        if math.isinf(self._best) or relative_improvement > self.min_delta:
            self._best = perplexity
            self._no_improve = 0
        else:
            self._no_improve += 1

        if self._no_improve >= self.patience:
            self._penalty *= self.factor
            self._no_improve = 0
            new_lr = self.scheduler.get_last_lr()[0] * self.factor
            print(f"\n[PlateauLRGuard] No improvement for {self.patience} checks. "
                  f"Reducing LR by {self.factor:.2f}x -> {new_lr:.2e}")

            # Scale the base LRs so the reduction survives future
            # scheduler.step() calls (LambdaLR recomputes from base_lrs)...
            self.scheduler.base_lrs = [lr * self.factor for lr in self.scheduler.base_lrs]
            # ...and the live values so it takes effect immediately.
            for pg in self.scheduler.optimizer.param_groups:
                pg['lr'] *= self.factor
            return True

        return False

    def get_penalty(self):
        """Cumulative multiplicative penalty applied so far (1.0 = none)."""
        return self._penalty
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed: int):
    """Seed every RNG (python, numpy, torch CPU + CUDA) and pin cuDNN
    to deterministic kernels for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
|
|
|
|
class EWC:
    """
    Elastic Weight Consolidation - prevents catastrophic forgetting during finetuning.

    How it works:
    After training on the original dataset, we compute the Fisher Information
    Matrix (diagonal approximation) for each parameter. The Fisher value for
    a parameter measures how much the loss changes when that parameter is nudged
    - i.e. how "important" it is for the old task.

    During finetuning on new data, an EWC penalty term is added to the loss:
        ewc_loss = (ewc_lambda / 2) * sum_i [ F_i * (theta_i - theta_i*)^2 ]

    Parameters that were important (high F_i) are penalised heavily for
    moving away from their old values (theta_i*). Parameters that weren't
    important can move freely to learn the new task.

    Usage:
        ewc = EWC(model, old_dataloader, device, n_samples=2000)
        # then inside training loop:
        loss = task_loss + ewc.penalty(model)
    """

    def __init__(self, model, dataloader, device, n_samples: int = 2000):
        self.device = device
        self.n_samples = n_samples  # cap on samples used for the Fisher estimate

        # Snapshot of the "old task" parameter values (theta_i*); detached
        # copies so later training cannot mutate them.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

        # Diagonal Fisher estimate, computed once up front.
        self.fisher = self._compute_fisher(model, dataloader)

    def _compute_fisher(self, model, dataloader):
        """Estimate the diagonal Fisher as the mean of squared gradients
        of the task loss over up to n_samples samples from `dataloader`."""
        fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}

        model.eval()
        seen = 0  # counts individual samples, not batches

        for batch in dataloader:
            if seen >= self.n_samples:
                break

            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)

            # Fresh gradients per batch; accumulate squared grads into fisher.
            model.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            outputs["loss"].backward()

            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher[n] += p.grad.detach().pow(2)

            seen += input_ids.size(0)

        # Normalise by the number of samples actually seen.
        for n in fisher:
            fisher[n] /= max(1, seen)

        model.train()
        return fisher

    def penalty(self, model) -> torch.Tensor:
        """Return EWC penalty term to add to the task loss."""
        loss = torch.tensor(0.0, device=self.device)
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.fisher:
                # F_i * (theta_i - theta_i*)^2, summed over all parameters.
                loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        # The 1/2 from the EWC formula; ewc_lambda is applied by the caller.
        return loss * 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model(
    model: IndonesianLLM,
    train_dataset: IndonesianCoTDataset,
    config: TrainingConfig,
    device: torch.device,
    use_simple_curriculum: bool = False,
    is_continue: bool = False,
    skip_curriculum_stages: int = 0,
    ewc: "EWC | None" = None,
):
    """Main training loop.

    Runs curriculum-staged training: one DataLoader per curriculum stage,
    `config.num_epochs` epochs per stage, with gradient accumulation,
    optional fp16 (GradScaler), LR scheduling, and — in continue mode —
    a plateau LR guard plus an optional EWC anti-forgetting penalty.

    Args:
        model: model to train (moved to `device` in place).
        train_dataset: full dataset; sliced into curriculum stages internally.
        config: hyperparameters (see TrainingConfig).
        device: target device.
        use_simple_curriculum: length-only stages instead of the 3-stage
            reasoning curriculum.
        is_continue: continue-training mode (no-warmup scheduler + plateau guard).
        skip_curriculum_stages: skip the first N stages (continue mode).
        ewc: optional EWC object; its penalty is weighted by config.ewc_lambda.

    Returns:
        The trained model (same object as `model`).
    """

    # --- banner ---------------------------------------------------------
    print("\n" + "="*80)
    print("TRAINING CONFIGURATION" + (" [CONTINUE MODE]" if is_continue else ""))
    print("="*80)
    print(f"Model parameters: {model.count_parameters():,}")
    print(f"Dataset size: {len(train_dataset)}")
    print(f"Batch size: {config.batch_size}")
    print(f"Gradient accumulation steps: {config.gradient_accumulation_steps}")
    print(f"Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Max sequence length: {config.max_seq_length}")
    print(f"Number of epochs: {config.num_epochs}")
    print(f"Mixed precision: {config.use_fp16}")
    if is_continue:
        print(f"Skipping curriculum stages: {skip_curriculum_stages}")
        print(f"Plateau patience: {config.plateau_patience}")
        print(f"Plateau LR factor: {config.plateau_factor}")
    if ewc is not None:
        print(f"EWC lambda: {config.ewc_lambda} (anti-forgetting active)")
    print("="*80 + "\n")

    model.to(device)
    model.train()

    curriculum_datasets = create_curriculum_datasets(
        train_dataset,
        config.curriculum_stages,
        use_simple=use_simple_curriculum,
        skip_stages=skip_curriculum_stages,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon,
        weight_decay=config.weight_decay
    )

    # Estimate total optimizer steps across all stages for the LR schedule.
    total_steps = 0
    for ds in curriculum_datasets:
        steps_per_epoch = max(1, len(ds) // (config.batch_size * config.gradient_accumulation_steps))
        total_steps += steps_per_epoch * config.num_epochs

    if total_steps == 0:
        total_steps = 1

    # --- scheduler selection -------------------------------------------
    if is_continue:
        # No warmup: decay gently from the (already small) continue LR.
        scheduler = get_continue_schedule(optimizer, num_training_steps=total_steps)
        plateau_guard = PlateauLRGuard(
            scheduler,
            patience=config.plateau_patience,
            factor=config.plateau_factor,
            min_delta=config.plateau_min_delta,
        )
        print(f"[Scheduler] Continue-train: flat cosine decay from {config.learning_rate:.2e}")
    else:
        if config.lr_scheduler_type == "cosine":
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=config.warmup_steps,
                num_training_steps=total_steps
            )
        else:
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=config.warmup_steps,
                num_training_steps=total_steps
            )
        plateau_guard = None

    # fp16 loss scaling only when CUDA is actually available.
    scaler = torch.cuda.amp.GradScaler() if config.use_fp16 and torch.cuda.is_available() else None

    global_step = 0
    perplexity_history = []

    # --- curriculum stages ---------------------------------------------
    for stage_idx, stage_dataset in enumerate(curriculum_datasets):
        # Skipped stages still count toward the displayed stage number.
        actual_stage = stage_idx + skip_curriculum_stages
        print(f"\n{'='*80}")
        print(f"CURRICULUM STAGE {actual_stage + 1}/{len(curriculum_datasets) + skip_curriculum_stages} "
              f"(running {stage_idx + 1} of {len(curriculum_datasets)})")
        print(f"Max sequence length: {stage_dataset.max_length}")
        print(f"Samples: {len(stage_dataset)}")
        if hasattr(stage_dataset, 'cot_ratio'):
            print(f"CoT ratio: {stage_dataset.cot_ratio:.0%}")
        print(f"{'='*80}\n")

        dataloader = DataLoader(
            stage_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
            num_workers=0,
            pin_memory=True if torch.cuda.is_available() else False
        )

        for epoch in range(config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
            epoch_loss = 0.0
            optimizer.zero_grad()

            for step, batch in enumerate(dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                if scaler is not None:
                    # fp16 path: forward under autocast, scaled backward.
                    with torch.cuda.amp.autocast():
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                        task_loss = outputs['loss']
                        if ewc is not None:
                            task_loss = task_loss + config.ewc_lambda * ewc.penalty(model)
                        # Normalise so accumulated grads match one large batch.
                        loss = task_loss / config.gradient_accumulation_steps
                    scaler.scale(loss).backward()
                else:
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    task_loss = outputs['loss']
                    if ewc is not None:
                        task_loss = task_loss + config.ewc_lambda * ewc.penalty(model)
                    loss = task_loss / config.gradient_accumulation_steps
                    loss.backward()

                # Undo the accumulation division for logging purposes.
                epoch_loss += loss.item() * config.gradient_accumulation_steps

                # Optimizer step only every gradient_accumulation_steps batches.
                if (step + 1) % config.gradient_accumulation_steps == 0:
                    if scaler is not None:
                        # Unscale before clipping so the norm is in true units.
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                    if scaler is not None:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()

                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    if global_step % config.logging_steps == 0:
                        # Running mean of the loss over this epoch so far.
                        avg_loss = epoch_loss / (step + 1)
                        current_lr = scheduler.get_last_lr()[0]
                        print(f"Step {global_step:>6} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")

            avg_epoch_loss = epoch_loss / max(1, len(dataloader))
            # Cap the exponent to avoid float overflow on divergent runs.
            perplexity = math.exp(min(avg_epoch_loss, 20))
            perplexity_history.append(perplexity)

            print(f"Epoch {epoch + 1} completed | Avg Loss: {avg_epoch_loss:.4f} "
                  f"| Perplexity: {perplexity:.2f}")

            # Continue mode: let the plateau guard react to each epoch's perplexity.
            if plateau_guard is not None:
                plateau_guard.step(perplexity)

    print("\n" + "="*80)
    print("TRAINING COMPLETED")
    if perplexity_history:
        print(f"Final perplexity: {perplexity_history[-1]:.2f}")
        print(f"Best perplexity: {min(perplexity_history):.2f}")
    print("="*80 + "\n")

    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_model(model, dataset, device, batch_size=4):
    """Compute sample-weighted average loss and perplexity of `model`
    over `dataset`, print a qualitative summary, and return both metrics."""
    model.eval()

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda items: collate_fn_with_packing(items, pad_token_id=model.padding_idx),
        num_workers=0
    )

    loss_sum = 0.0
    sample_count = 0

    with torch.no_grad():
        for batch in loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            result = model(input_ids=ids, attention_mask=mask, labels=labels)
            n = ids.size(0)
            loss_sum += result['loss'].item() * n  # weight by batch size
            sample_count += n

    avg_loss = loss_sum / max(1, sample_count)
    perplexity = math.exp(min(avg_loss, 20))  # cap exponent to avoid overflow

    print(f"\nEvaluation Results:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")

    # Qualitative banding of the perplexity result.
    if perplexity < 5.0:
        print("Status: Excellent")
    elif perplexity < 10.0:
        print("Status: Good")
    elif perplexity < 20.0:
        print("Status: Fair β try more epochs or lower LR")
    else:
        print("Status: Poor β check data quality or model config")

    if avg_loss < 0.5:
        print("Warning: Very low loss might indicate overfitting")

    return {"loss": avg_loss, "perplexity": perplexity}
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_text(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9,
    device: torch.device = torch.device('cpu')
) -> str:
    """
    Autoregressively sample up to `max_new_tokens` tokens after `prompt`.

    Sampling pipeline per step: temperature scaling (floored at 0.1), a
    -2.0 additive repetition penalty on tokens seen in the last 10 generated
    tokens, top-k filtering, then nucleus (top-p) filtering, then multinomial
    sampling. Generation stops on a stop token, or heuristically when the
    decoded continuation contains a blank line or a "User:"/"Assistant:"
    turn marker.

    Returns the FULL decoded sequence (prompt + continuation), with special
    tokens kept visible so callers can parse <cot>...</cot> blocks.
    """
    model.eval()

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_ids = input_ids.clone()

    # Resolve a usable EOS id: tokenizer eos, else sep (BERT-style tokenizers
    # such as IndoBERT have sep but no eos), else the ModelConfig default 2.
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = tokenizer.sep_token_id
    if eos_token_id is None:
        eos_token_id = 2

    # pad_token_id may be None here; harmless, since sampled ids are ints.
    stop_tokens = {eos_token_id, tokenizer.pad_token_id}
    if tokenizer.sep_token_id is not None:
        stop_tokens.add(tokenizer.sep_token_id)

    # Sliding window (max 20) of recently generated token ids, used for the
    # repetition penalty.
    repetition_buffer = []

    with torch.no_grad():
        for step in range(max_new_tokens):
            # No KV cache: the full sequence is re-encoded every step
            # (config.use_cache defaults to False), so this loop is O(n^2).
            outputs = model(input_ids=generated_ids)
            logits = outputs['logits']
            next_token_logits = logits[:, -1, :] / max(temperature, 0.1)

            # Additive repetition penalty on the last 10 generated tokens
            # (stop tokens excluded so EOS can still be emitted).
            if len(repetition_buffer) > 10:
                for token in set(repetition_buffer[-10:]):
                    if token in stop_tokens:
                        continue
                    next_token_logits[0, token] -= 2.0

            # Top-k: mask everything below the k-th largest logit.
            if top_k > 0:
                indices_to_remove = next_token_logits < torch.topk(
                    next_token_logits, min(top_k, next_token_logits.size(-1))
                )[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')

            # Nucleus (top-p): drop the tail of the sorted distribution past
            # cumulative probability top_p; the shift keeps the first token
            # above the threshold so at least one candidate always survives.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            if next_token.item() in stop_tokens:
                break

            repetition_buffer.append(next_token.item())
            if len(repetition_buffer) > 20:
                repetition_buffer.pop(0)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Keep the sequence within the model's positional range.
            # NOTE(review): truncating here drops the OLDEST tokens, including
            # the prompt — the `input_ids.size(1):` slice below and callers
            # slicing by len(prompt) then misalign. Only matters for very
            # long generations; confirm before relying on long outputs.
            max_ctx = model.config.max_position_embeddings
            if generated_ids.size(1) > max_ctx:
                generated_ids = generated_ids[:, -max_ctx:]

            # Heuristic early stop: re-decode the continuation and bail on a
            # blank line or a hallucinated dialogue-turn marker.
            if step > 10:
                decoded = tokenizer.decode(
                    generated_ids[0][input_ids.size(1):], skip_special_tokens=False)
                if '\n\n' in decoded or 'User:' in decoded or 'Assistant:' in decoded[-20:]:
                    break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| def _clean_response(response: str) -> str:
|
| """
|
| Strip CoT block and training-format artifacts, return only the final answer.
|
| Assumes generate_text() decoded with skip_special_tokens=False so tags visible.
|
| """
|
| import re
|
|
|
|
|
| if "<cot>" in response and "</cot>" in response:
|
| response = response.split("</cot>", 1)[-1]
|
|
|
| elif "<cot>" in response:
|
| response = response.split("<cot>", 1)[0]
|
|
|
|
|
| response = re.sub(r'\[\w+\]', '', response)
|
| response = re.sub(r'<[^>]+>', '', response)
|
|
|
|
|
| for marker in [
|
| "user :", "user:", "User :", "User:",
|
| "assistant :", "assistant:", "Assistant :", "Assistant:",
|
| "memahami permintaan", "jawaban singkat", "penjelasan harus",
|
| "\n\n",
|
| ]:
|
| if marker in response:
|
| response = response.split(marker)[0]
|
|
|
|
|
| response = re.sub(r'^[\s:!,\.\-|\[\]]+', '', response)
|
| response = re.sub(r' {2,}', ' ', response).strip()
|
| return response
|
|
|
|
|
def _extract_thinking(raw: str) -> tuple:
    """
    Split text generated AFTER the <cot> prompt token into
    (thinking_text, answer_text).

    Expected model output shape: "...reasoning... </cot> ...answer...".
    The two halves are cleaned independently; the answer half goes through
    _clean_response().
    """
    import re

    # Drop bracketed special-token residue ([PAD], [SEP], ...) up front.
    raw = re.sub(r'\[\w+\]', '', raw)

    # partition() gives exactly the fallback we want: when the closing tag
    # is absent, everything is reasoning and the answer half is empty.
    thinking_part, _, answer_part = raw.partition("</cot>")

    # Cut the reasoning at the first leaked dialogue label, CoT boilerplate
    # phrase, or paragraph break, then strip any leftover tags.
    thinking = thinking_part.strip()
    for cut in ("user :", "user:", "memahami permintaan", "\n\n"):
        thinking = thinking.partition(cut)[0]
    thinking = re.sub(r'<[^>]+>', '', thinking).strip()

    return thinking, _clean_response(answer_part)
|
|
|
|
|
def interactive_chat(model, tokenizer, device,
                     system_prompt: str = "Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia."):
    """
    Interactive REPL over the model.

    Prompts use the bare "<user text> <cot>" format the model was trained on;
    the persona string is display-only and never injected into the prompt.
    Commands: exit/quit/keluar, clear/bersihkan, think (toggle CoT display).
    """
    divider = "=" * 80
    print("\n" + divider)
    print("INDONESIAN LLM β INTERACTIVE CHAT")
    print(divider)
    print("Commands: 'exit'/'quit' | 'clear' | 'think' (toggle reasoning display)")
    print(f"Persona : {system_prompt}")
    print(divider + "\n")

    model.eval()
    show_thinking = False

    # Seed from wall-clock time so each chat session samples differently.
    import time
    session_seed = int(time.time()) % 100000
    torch.manual_seed(session_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(session_seed)

    while True:
        try:
            user_text = input("\nYou: ").strip()
            if not user_text:
                continue

            command = user_text.lower()
            if command in ('exit', 'quit', 'keluar'):
                print("\nGoodbye!")
                break
            if command in ('clear', 'bersihkan'):
                print("\nConversation cleared")
                continue
            if command == 'think':
                show_thinking = not show_thinking
                print(f"\nThinking mode: {'ON' if show_thinking else 'OFF'}")
                continue

            # Bare training format: user text followed by the CoT open tag.
            prompt = f"{user_text} <cot>"

            print("\nA:", end=" ", flush=True)
            raw = generate_text(
                model=model, tokenizer=tokenizer, prompt=prompt,
                max_new_tokens=250, temperature=0.9,
                top_k=50, top_p=0.92, device=device
            )
            # NOTE(review): slicing by len(prompt) assumes the decoded text
            # starts with the prompt verbatim; an uncased tokenizer may
            # normalize it β confirm against the tokenizer in use.
            continuation = raw[len(prompt):].strip()

            thinking, answer = _extract_thinking(continuation)

            if show_thinking and thinking:
                print(f"[Thinking: {thinking}]")

            reply = answer or _clean_response(continuation)
            if not reply or len(reply) < 3:
                reply = "Maaf, saya tidak mengerti. Bisa diulang?"
            print(reply)

        except KeyboardInterrupt:
            print("\n\nChat interrupted")
            break
        except Exception as e:
            print(f"\nError: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_benchmark(model, tokenizer, device, dataset_path: str = None, n: int = 20, verbose: bool = True):
    """
    Benchmark the model on `n` random samples drawn from the training dataset.

    A sample PASSES when either:
      1. the expected answer appears verbatim (substring) in the model output, or
      2. at least half of the expected answer's words appear in the output.

    Also reports token-weighted teacher-forced perplexity over the gold
    "input <cot> output" sequences.

    The question subset is re-drawn every run (time-based seed) while the
    sampling RNG is fixed at 42, so run-to-run variance comes from which
    questions were picked, not from generation noise.

    Returns {"score", "pass", "total", "perplexity"}, or None when the
    dataset is missing or contains no valid samples.
    """
    import time

    if dataset_path is None or not os.path.exists(dataset_path):
        print(f"Dataset not found at: {dataset_path}")
        print("Pass --dataset path/to/your.jsonl to benchmark against your data.")
        return

    # Collect every well-formed {input, output} record; malformed JSONL lines
    # are skipped silently (best-effort ingestion).
    all_samples = []
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                d = json.loads(line)
                if all(k in d for k in ['input', 'output']) and d['input'].strip() and d['output'].strip():
                    all_samples.append(d)
            except Exception:
                continue

    if not all_samples:
        print("No valid samples found in dataset.")
        return

    # Fresh random subset per run.
    random.seed(int(time.time()))
    samples = random.sample(all_samples, min(n, len(all_samples)))

    model.eval()

    # Deterministic generation so scores are comparable across runs.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    print("\n" + "="*80)
    print(f"BENCHMARK ({len(samples)} random samples from dataset)")
    print("="*80)

    results = []
    ppl_loss = 0.0
    ppl_toks = 0

    for sample in samples:
        inp = sample['input'].strip()
        expected = sample['output'].strip().lower()

        prompt = f"{inp} <cot>"
        full = generate_text(
            model=model, tokenizer=tokenizer, prompt=prompt,
            max_new_tokens=150, temperature=0.3,
            top_k=20, top_p=0.9, device=device
        )
        raw = full[len(prompt):].strip()
        _, answer = _extract_thinking(raw)
        answer_lower = answer.lower()

        # Criterion 1: exact substring match.
        passed = expected in answer_lower

        # Criterion 2: >= 50% word overlap with the expected answer.
        if not passed:
            exp_tokens = set(expected.split())
            ans_tokens = set(answer_lower.split())
            if exp_tokens:
                overlap = len(exp_tokens & ans_tokens) / len(exp_tokens)
                passed = overlap >= 0.5

        results.append(passed)

        # Teacher-forced loss on the gold sequence, weighted by the number
        # of predicted positions (N-1 for next-token prediction).
        with torch.no_grad():
            ids = tokenizer.encode(
                f"{inp} <cot> {sample['output']}",
                return_tensors="pt"
            ).to(device)
            if ids.size(1) >= 2:
                out = model(input_ids=ids, labels=ids)
                toks = ids.size(1) - 1
                ppl_loss += out["loss"].item() * toks
                ppl_toks += toks

        if verbose:
            status = "PASS" if passed else "FAIL"
            print(f" [{status}] {inp[:60]}")
            print(f" Expected : {sample['output'][:80]}")
            print(f" Got : {answer[:80] if answer else '(no answer)'}")

    total_pass = sum(results)
    total = len(results)
    overall = total_pass / total * 100

    # BUGFIX: the original drew both the filled and empty bar segments with
    # the same (mojibake) character, so the bar never reflected the score.
    filled = int(overall / 10)
    bar = "█" * filled + "░" * (10 - filled)

    ppl = math.exp(min(ppl_loss / max(1, ppl_toks), 20))

    print("\n" + "-"*80)
    print(f" SCORE {total_pass}/{total} ({overall:.1f}%) {bar}")
    print(f" PERPLEXITY {ppl:.2f} (lower = better)")
    print("="*80 + "\n")

    return {"score": overall, "pass": total_pass, "total": total, "perplexity": ppl}
|
|
|
|
|
|
|
|
|
|
|
def save_model(model: IndonesianLLM, config: ModelConfig, tokenizer_name: str, save_path: str, use_fp16: bool = True):
    """
    Persist model weights + config as a single .pt file.

    With use_fp16=True, float32 tensors are downcast to half precision
    before saving (roughly halves the file size); load_model() upcasts
    them back to float32.
    """
    target_dir = os.path.dirname(save_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    state_dict = model.state_dict()
    if use_fp16:
        # Downcast only fp32 tensors; integer buffers etc. stay untouched.
        state_dict = {
            name: (tensor.half() if tensor.dtype == torch.float32 else tensor)
            for name, tensor in state_dict.items()
        }

    payload = {
        'model_state_dict': state_dict,
        'config': config,
        'tokenizer_name': tokenizer_name,
        'model_params': model.count_parameters(),
        'dtype': 'fp16' if use_fp16 else 'fp32',
    }
    torch.save(payload, save_path)

    size_mb = os.path.getsize(save_path) / 1e6
    print(f"\nModel saved to: {save_path} ({'fp16' if use_fp16 else 'fp32'}, {size_mb:.1f} MB)")
    print(f"Parameters: {model.count_parameters():,}")
|
|
|
|
|
def load_model(load_path: str, device: torch.device):
    """
    Load a checkpoint written by save_model().

    Rebuilds the tokenizer (re-adding the <cot>/</cot> special tokens),
    reconstructs the model from the saved config, upcasts fp16 weights back
    to fp32, and moves the model to `device`.

    Returns a 4-tuple (model, tokenizer, config, extra) where
    extra['training_metadata'] may be None.
    """
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"Model checkpoint not found: {load_path}")

    print(f"Loading model from: {load_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects (needed
    # here because the config dataclass is stored in the checkpoint) β only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(load_path, map_location=device, weights_only=False)

    config = checkpoint['config']
    tokenizer_name = checkpoint['tokenizer_name']
    dtype = checkpoint.get('dtype', 'fp32')

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})

    model = IndonesianLLM(config)

    state_dict = checkpoint['model_state_dict']
    if dtype == 'fp16':
        # Restore full precision for stable training/inference numerics.
        state_dict = {
            name: (tensor.float() if tensor.dtype == torch.float16 else tensor)
            for name, tensor in state_dict.items()
        }

    model.load_state_dict(state_dict)
    model.to(device)

    extra = {'training_metadata': checkpoint.get('training_metadata', None)}

    size_mb = os.path.getsize(load_path) / 1e6
    print(f"Model loaded ({dtype}, {size_mb:.1f} MB) | "
          f"Parameters: {checkpoint.get('model_params', model.count_parameters()):,}")

    return model, tokenizer, config, extra
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    CLI entry point.

    Dispatches to one of six modes based on flags: --inspect-data, --chat,
    --benchmark, --train, --finetune, --continue-train. Train-like modes
    build the dataset and configs, run train_model(), optionally evaluate,
    and save a checkpoint.
    """
    parser = argparse.ArgumentParser(
        description="Indonesian Conversational LLM β Train, Chat, Finetune, or Continue")

    # --- Mode flags ---------------------------------------------------------
    parser.add_argument('--train', action='store_true', help='Train model from scratch')
    parser.add_argument('--chat', action='store_true', help='Interactive chat mode')
    parser.add_argument('--finetune', action='store_true', help='Fine-tune on NEW data (lr/10)')
    parser.add_argument('--continue-train', action='store_true',
                        help='Continue training on SAME data with proper LR re-warmup')
    parser.add_argument('--inspect-data', action='store_true', help='Inspect dataset quality')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmark suite on a saved model')
    parser.add_argument('--no-eval', action='store_true', help='Skip evaluation after train/finetune/continue-train')
    parser.add_argument('--grad-accum', type=int, default=None, help='Override gradient accumulation steps (default: 32 train, 32 finetune, 32 continue)')
    parser.add_argument('--ewc-lambda', type=float, default=5000.0,
                        help='EWC penalty strength for --finetune (default 5000). Higher = less forgetting but slower to learn new data. 0 = disabled.')
    parser.add_argument('--ewc-samples', type=int, default=2000,
                        help='Samples used to estimate Fisher Information for EWC (default 2000).')
    parser.add_argument('--no-ewc', action='store_true',
                        help='Disable EWC during finetuning (allow forgetting old data).')

    # --- Paths and core hyperparameters --------------------------------------
    parser.add_argument('--dataset', type=str, default='indonesian_cot_dataset.jsonl')
    parser.add_argument('--model', type=str, default='indonesian_llm_model.pt')
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--max-length', type=int, default=512,
                        help='Max sequence length. Your dataset max is 448 so 512 is optimal.')
    parser.add_argument('--hidden-size', type=int, default=320)
    parser.add_argument('--num-layers', type=int, default=16)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--num-kv-heads', type=int, default=2)
    parser.add_argument('--seed', type=int, default=42)
    # NOTE(review): --save-fp16 and --use-cot are effectively no-ops β the
    # code below derives save_fp16 from --save-fp32 and CoT usage from
    # --no-cot. Kept for CLI backward compatibility.
    parser.add_argument('--save-fp16', action='store_true', default=True)
    parser.add_argument('--save-fp32', action='store_true')
    parser.add_argument('--no-cot', action='store_true')
    parser.add_argument('--use-cot', action='store_true', default=True)
    parser.add_argument('--simple-curriculum', action='store_true')
    parser.add_argument('--cot-ratio', type=float, default=1.0)
    parser.add_argument(
        '--system-prompt', type=str,
        default='Kamu adalah asisten AI yang membantu, ramah, dan menjawab dalam Bahasa Indonesia.',
        help='System prompt prepended to every chat turn.'
    )

    # --- Continue-train tuning ------------------------------------------------
    parser.add_argument(
        '--rewarm-steps', type=int, default=150,
        help='Steps to re-warm LR from rewarm-start-frac β full LR (continue-train only). '
             'Default 150. Increase for larger models or when previous run ended far into decay.')
    parser.add_argument(
        '--skip-stages', type=int, default=2,
        help='Number of early curriculum stages to skip in continue-train (default 2). '
             'Skip=2 means start directly at Stage 3 (full CoT, full length). '
             'Use 0 to re-run all stages (not recommended).')
    parser.add_argument(
        '--plateau-patience', type=int, default=3,
        help='Epochs without improvement before LR is halved (default 3).')
    parser.add_argument(
        '--no-restore-optimizer', action='store_true',
        help='Do NOT restore optimizer state from checkpoint (useful for debugging).')

    args = parser.parse_args()

    if not any([args.train, args.chat, args.finetune, args.continue_train, args.inspect_data, args.benchmark]):
        parser.print_help()
        # BUGFIX: --benchmark is a valid mode (checked above) but was missing
        # from this error message.
        print("\nError: Specify a mode: --train, --chat, --finetune, --continue-train, --inspect-data, or --benchmark")
        return

    save_fp16 = not args.save_fp32
    use_cot_training = not args.no_cot

    set_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nDevice: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n")

    # ------------------------------------------------------------------
    # Mode: dataset inspection (no model involved)
    # ------------------------------------------------------------------
    if args.inspect_data:
        print("\nInspecting dataset...")
        print("="*80)

        tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
        tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})

        dataset = IndonesianCoTDataset(
            file_path=args.dataset, tokenizer=tokenizer,
            max_length=args.max_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio
        )

        print(f"\nDataset Statistics:")
        print(f"Total samples: {len(dataset)}")
        print(f"Skipped samples: {dataset.skipped_count}")

        # Length statistics over at most the first 1000 samples.
        lengths = []
        cot_lengths = []
        for i in range(min(len(dataset), 1000)):
            sample = dataset[i]
            lengths.append(sample['length'])
            cot_lengths.append(sample['cot_length'])

        print(f"\nSequence Length Stats:")
        print(f" Min: {min(lengths)}")
        print(f" Max: {max(lengths)}")
        print(f" Avg: {sum(lengths)/len(lengths):.1f}")
        print(f" Median: {sorted(lengths)[len(lengths)//2]}")

        print(f"\nCoT Length Stats:")
        print(f" Min: {min(cot_lengths)}")
        print(f" Max: {max(cot_lengths)}")
        print(f" Avg: {sum(cot_lengths)/len(cot_lengths):.1f}")

        long_cot = sum(1 for x in cot_lengths if x > 50)
        print(f" Samples with long CoT (>50 tokens): {long_cot} ({long_cot/len(cot_lengths)*100:.1f}%)")

        print(f"\n{'='*80}")
        print("Sample Examples (first 5):")
        print("="*80)
        for i in range(min(5, len(dataset.samples))):
            s = dataset.samples[i]
            print(f"\n--- Sample {i+1} ---")
            print(f"Input: {s['input'][:100]}...")
            print(f"CoT: {s['cot'][:150]}...")
            print(f"Output: {s['output'][:100]}...")

        # Cheap heuristics over the first 100 samples to surface data bugs.
        print("\n" + "="*80)
        print("Dataset Quality Checks:")
        issues = []
        for i, sample in enumerate(dataset.samples[:100]):
            if len(sample['input']) < 10:
                issues.append(f"Sample {i}: Input too short")
            if len(sample['output']) < 10:
                issues.append(f"Sample {i}: Output too short")
            if len(sample['cot']) < 20:
                issues.append(f"Sample {i}: CoT too short")
            if sample['input'].lower() == sample['output'].lower():
                issues.append(f"Sample {i}: Input == Output (copy)")

        if issues:
            print(f"\nFound {len(issues)} potential issues in first 100 samples:")
            for issue in issues[:10]:
                print(f" - {issue}")
            if len(issues) > 10:
                print(f" ... and {len(issues)-10} more")
        else:
            print("\nNo obvious issues detected in first 100 samples")
        print("\n" + "="*80)
        return

    # ------------------------------------------------------------------
    # Mode: interactive chat
    # ------------------------------------------------------------------
    if args.chat:
        print("\nStarting CHAT mode...")
        if not os.path.exists(args.model):
            print(f"Error: Model checkpoint not found: {args.model}")
            return
        model, tokenizer, config, _ = load_model(args.model, device)
        interactive_chat(model, tokenizer, device, system_prompt=args.system_prompt)
        return

    # ------------------------------------------------------------------
    # Mode: benchmark a saved model
    # ------------------------------------------------------------------
    if args.benchmark:
        print("\nRunning benchmark...")
        if not os.path.exists(args.model):
            print(f"Error: Model not found: {args.model}")
            return
        model, tokenizer, config, _ = load_model(args.model, device)
        run_benchmark(model, tokenizer, device, dataset_path=args.dataset)
        return

    # ------------------------------------------------------------------
    # Mode: train from scratch
    # ------------------------------------------------------------------
    if args.train:
        print("\nStarting TRAINING mode (from scratch)...")

        tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
        tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})

        model_config = ModelConfig(
            vocab_size=len(tokenizer),
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            num_attention_heads=args.num_heads,
            num_key_value_heads=args.num_kv_heads,
            intermediate_size=args.hidden_size * 3,
            max_position_embeddings=2048,
            attention_dropout=0.1,
            residual_dropout=0.1,
            tie_word_embeddings=True
        )

        model = IndonesianLLM(model_config)
        print(f"Model parameters: {model.count_parameters():,}")

        _ga = args.grad_accum if args.grad_accum else 32
        train_config = TrainingConfig(
            dataset_path=args.dataset,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=_ga,
            max_seq_length=args.max_length,
            learning_rate=args.lr,
            warmup_steps=500,
            use_fp16=torch.cuda.is_available(),
            curriculum_stages=[128, 256, args.max_length]
        )

        dataset = IndonesianCoTDataset(
            file_path=train_config.dataset_path, tokenizer=tokenizer,
            max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio
        )

        model = train_model(
            model, dataset, train_config, device,
            use_simple_curriculum=args.simple_curriculum
        )

        if not args.no_eval:
            evaluate_model(model, dataset, device)

        save_model(model, model_config, "indolem/indobert-base-uncased", args.model, use_fp16=save_fp16)

        # Quick smoke-test generations after training.
        test_prompts = [
            "Berapa hasil dari 1+1?",
            "Jelaskan cara kerja komputer",
            "Bagaimana cara membuat kopi yang enak?"
        ]
        print("\n" + "="*80)
        print("GENERATION TEST")
        print("="*80 + "\n")
        for prompt in test_prompts:
            print(f"Prompt: {prompt}")
            generated = generate_text(model, tokenizer, prompt, max_new_tokens=150, device=device)
            print(f"Generated: {generated}\n")
            print("-" * 80 + "\n")

        print(f"\nTo chat: python {__file__} --chat --model {args.model}")

    # ------------------------------------------------------------------
    # Mode: fine-tune an existing checkpoint on NEW data
    # ------------------------------------------------------------------
    if args.finetune:
        print("\nStarting FINETUNE mode (for NEW data)...")

        if not os.path.exists(args.model):
            print(f"Error: Model checkpoint not found: {args.model}")
            return

        model, tokenizer, model_config, extra = load_model(args.model, device)

        _ga = args.grad_accum if args.grad_accum else 32
        train_config = TrainingConfig(
            dataset_path=args.dataset,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=_ga,
            max_seq_length=args.max_length,
            learning_rate=args.lr / 10,  # conservative LR for fine-tuning
            warmup_steps=100,
            use_fp16=torch.cuda.is_available(),
            curriculum_stages=[128, 256, args.max_length]
        )

        dataset = IndonesianCoTDataset(
            file_path=train_config.dataset_path, tokenizer=tokenizer,
            max_length=train_config.max_seq_length, use_cot=use_cot_training, cot_ratio=args.cot_ratio
        )

        print(f"\nStarting fine-tuning with LR={train_config.learning_rate:.2e}...")

        # Elastic Weight Consolidation: penalize movement of parameters that
        # were important for the previous task.
        # NOTE(review): the Fisher estimate below is computed on args.dataset,
        # i.e. the NEW fine-tuning data β to actually anchor OLD knowledge
        # the original training dataset path should be used here; confirm.
        ewc_obj = None
        if not args.no_ewc and args.ewc_lambda > 0:
            print(f"\nComputing EWC Fisher Information (lambda={args.ewc_lambda}, samples={args.ewc_samples})...")
            print(" This takes ~1-2 min on T4. Prevents forgetting old training data.")
            old_dataset = IndonesianCoTDataset(
                file_path=args.dataset, tokenizer=tokenizer,
                max_length=train_config.max_seq_length, use_cot=use_cot_training,
                cot_ratio=args.cot_ratio
            )
            old_loader = DataLoader(
                old_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                collate_fn=lambda x: collate_fn_with_packing(x, pad_token_id=model.padding_idx),
                num_workers=0
            )
            train_config.ewc_lambda = args.ewc_lambda
            train_config.ewc_samples = args.ewc_samples
            ewc_obj = EWC(model, old_loader, device, n_samples=args.ewc_samples)
            print(f" Fisher computed. EWC penalty will be added during training.")
        else:
            print(" EWC disabled β model may forget previous training (use --ewc-lambda to enable).")

        model = train_model(
            model, dataset, train_config, device,
            use_simple_curriculum=args.simple_curriculum,
            ewc=ewc_obj,
        )

        if not args.no_eval:
            evaluate_model(model, dataset, device)

        # Never overwrite the base checkpoint when fine-tuning.
        finetuned_path = args.model.replace('.pt', '_finetuned.pt')
        save_model(model, model_config, "indolem/indobert-base-uncased", finetuned_path, use_fp16=save_fp16)

        print(f"\nFine-tuning completed. Model saved to: {finetuned_path}")
        print(f"To chat: python {__file__} --chat --model {finetuned_path}")

    # ------------------------------------------------------------------
    # Mode: continue training on the SAME data
    # ------------------------------------------------------------------
    if args.continue_train:
        print("\nStarting CONTINUE-TRAIN mode (improved)...")
        print("="*80)
        print("NOTE: For a 15-30M param model on Indonesian CoT, perplexity 5-15 is normal.")
        print(" The old thresholds were calibrated for 1B+ models β ignore 'Poor' labels.")
        print("Key improvements over the original continue-train:")
        print(" 1. Optimizer state restored β smooth Adam updates from step 1")
        print(" 2. Re-warmup cosine schedule β no destructive LR spike at restart")
        print(" 3. Early stages skipped β no wasted time on short sequences")
        print(" 4. Plateau LR reduction β auto-halves LR if perplexity stalls")
        print("="*80)

        if not os.path.exists(args.model):
            print(f"Error: Model checkpoint not found: {args.model}")
            return

        model, tokenizer, model_config, extra = load_model(args.model, device)

        # Micro LR: the optimizer cold-starts, so 5% of the base LR avoids
        # overshooting the already-trained minimum.
        effective_lr = args.lr * 0.05
        print(f"\nContinue-train LR: {args.lr:.2e} Γ 0.05 = {effective_lr:.2e}")
        print(f" Adam cold-starts β micro LR prevents overshooting the trained minimum.")
        print(f" Override with --lr 1e-5 (or any explicit value).")

        # Longer curriculum stages than scratch training; usually the early
        # (short-sequence) stages are skipped entirely.
        curriculum_stages = [192, 320, args.max_length]
        if args.simple_curriculum:
            effective_skip = len(curriculum_stages) - 1
        else:
            effective_skip = args.skip_stages
        print(f"Skipping first {effective_skip} curriculum stage(s) β training on full-length data only.")

        _ga = args.grad_accum if args.grad_accum else 32

        # NOTE(review): --rewarm-steps, --plateau-patience and
        # --no-restore-optimizer are parsed above but not forwarded here;
        # plateau_patience is hard-coded to 2 below β confirm intent.
        train_config = TrainingConfig(
            dataset_path=args.dataset,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            gradient_accumulation_steps=_ga,
            max_seq_length=args.max_length,
            learning_rate=effective_lr,
            warmup_steps=0,
            use_fp16=torch.cuda.is_available(),
            curriculum_stages=curriculum_stages,
            skip_curriculum_stages=effective_skip,
            plateau_patience=2,
            plateau_factor=0.5,
            plateau_min_delta=0.02,
        )

        dataset = IndonesianCoTDataset(
            file_path=train_config.dataset_path, tokenizer=tokenizer,
            max_length=train_config.max_seq_length, use_cot=use_cot_training,
            cot_ratio=args.cot_ratio
        )

        model = train_model(
            model, dataset, train_config, device,
            use_simple_curriculum=args.simple_curriculum,
            is_continue=True,
            skip_curriculum_stages=effective_skip,
        )

        if not args.no_eval:
            print("\nEvaluating continued model...")
            evaluate_model(model, dataset, device)

        save_model(model, model_config, "indolem/indobert-base-uncased", args.model, use_fp16=save_fp16)

        print(f"\nContinued training completed.")
        print(f"Model saved to: {args.model}")
        print(f"To chat: python {__file__} --chat --model {args.model}")
|
|
|
|
|
# Script entry point: delegate to the CLI dispatcher.
if __name__ == "__main__":
    main()