import os

# Silence the HF tokenizers fork-parallelism warning; must be set before the
# tokenizers library spins up its thread pool.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.amp import GradScaler, autocast  # torch.cuda.amp is deprecated
from tqdm.auto import tqdm
from multiprocessing import freeze_support


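# Phase 3: self-contrastive fine-tuning. Starts from the Phase 2 triplet-trained
# LoRA adapters and trains them (plus a throwaway projection head) with an
# NT-Xent loss over two dropout-perturbed encodings of each query batch.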
def main():
    # Hyperparameters and paths.
    PRET_FILE = "pretokenized_queries.pt"
    MODEL_NAME = "google/gemma-3-1b-pt"
    LORA_DIR = "phase2_triplet_amp/final"  # adapters produced in Phase 2
    BATCH_SIZE = 64
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    TEMP = 0.05  # softmax temperature for the contrastive loss
    OUTPUT_DIR = "phase3_self_contrast"
    GRAD_CLIP_NORM = 1.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

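    # Load the pre-tokenized queries and build a shuffled training loader.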
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

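    # Load the base model and attach the Phase 2 LoRA adapters.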
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)

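    # Encoder: the LoRA-adapted backbone plus a small MLP head that projects
    # mean-pooled last-layer hidden states into a 256-d contrastive space.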
| |
| class GemmaSelfContrast(nn.Module): |
| def __init__(self, peft_model): |
| super().__init__() |
| self.peft = peft_model |
| hs = peft_model.base_model.config.hidden_size |
| self.proj = nn.Sequential( |
| nn.Linear(hs, 512), |
| nn.ReLU(), |
| nn.Linear(512, 256), |
| ) |
| def forward(self, ids, mask): |
| out = self.peft.base_model( |
| input_ids=ids, |
| attention_mask=mask, |
| output_hidden_states=True, |
| return_dict=True |
| ) |
| h = out.hidden_states[-1].mean(dim=1) |
| h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6) |
| z = self.proj(h) |
| z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6) |
| norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6) |
| return z / norm |
|
|
| model = GemmaSelfContrast(peft).to(device) |
|
|
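    # AdamW over the trainable parameters only (LoRA adapters + projection head;
    # PEFT freezes the base weights), with 10% linear warmup then linear decay.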
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
    scaler = GradScaler("cuda", enabled=(device.type == "cuda"))

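    # Self-contrastive training in the unsupervised-SimCSE style: each batch is
    # encoded twice in train mode, so dropout makes the two passes differ and
    # yields two "views" of every query. NT-Xent treats the other view of the
    # same query as the positive and the remaining 2B - 2 embeddings as negatives.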
    model.train()
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        for ids, mask in tqdm(loader, desc=f"Epoch {epoch}", unit="batch"):
            ids, mask = ids.to(device), mask.to(device)

            with autocast(device_type=device.type, enabled=(device.type == "cuda")):
                e1 = model(ids, mask)  # view 1
                e2 = model(ids, mask)  # view 2 (differs through dropout)
                emb = torch.cat([e1, e2], dim=0)  # (2B, 256)
                sim = (emb @ emb.T) / TEMP
                # Exclude each embedding's similarity to itself from the softmax.
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float("-inf"))

                # Row i's positive sits at i + B for the first half of the
                # batch and at i - B for the second half.
                B = e1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device),
                ], dim=0)
                loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
            torch.nn.utils.clip_grad_norm_(trainable, GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} avg loss: {avg_loss:.6f}")

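    # Persist only the LoRA adapters; the projection head exists solely to
    # shape the contrastive loss and is not saved.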
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    peft.save_pretrained(final_dir)
    print("Phase 3 complete. LoRA adapters saved to", final_dir)


if __name__ == "__main__":
    freeze_support()
    main()
    sys.exit(0)