import os

# Synchronous CUDA kernel launches make device-side errors surface at the offending
# call instead of a later op (debugging aid; slows training).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from tqdm.auto import tqdm
from multiprocessing import freeze_support
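
# Stage 1: SimCSE-style contrastive fine-tuning of Gemma with LoRA adapters and a small
# MLP projection head. Each batch is encoded twice, the two views of an example form a
# positive pair, and all other in-batch embeddings act as negatives (InfoNCE-style loss).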
|
|
def main():
    # Hyperparameters
    MODEL_NAME = "google/gemma-3-1b-pt"
    DATA_FILE = "text.txt"
    BATCH_SIZE = 12
    MAX_LENGTH = 128
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    VAL_RATIO = 0.1
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.0
    PROJ_HIDDEN = 512
    PROJ_OUT = 256
    TEMP = 0.05  # softmax temperature for the contrastive loss
    OUTPUT_DIR = "stage1_simcse"
    GRAD_CLIP_NORM = 1.0
    SIM_CLAMP_MIN = -10.0  # clamp range for the scaled similarity logits
    SIM_CLAMP_MAX = 10.0
    SEED = 42
|
|
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
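
    # SEED is currently only used for the train/validation split; for fully reproducible
    # runs the torch RNG could be seeded as well, e.g. torch.manual_seed(SEED).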
|
|
    # Tokenizer and base model; LoRA adapters are attached below, so the base weights
    # themselves stay frozen.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager",
    )
|
|
    # Attach LoRA adapters to the attention query/value projections.
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "v_proj"],
    )
    model_lora = get_peft_model(base_model, lora_cfg)
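
    # Optional sanity check: report how many parameters the adapters leave trainable.
    model_lora.print_trainable_parameters()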
|
|
    # Wrapper that pools the last hidden layer into a sentence embedding and passes it
    # through a two-layer MLP projection head, returning L2-normalized vectors.
    class GemmaSimCSE(nn.Module):
        def __init__(self, base):
            super().__init__()
            self.base = base
            hs = base.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, PROJ_HIDDEN),
                nn.ReLU(),
                nn.Linear(PROJ_HIDDEN, PROJ_OUT),
            )
|
|
        def forward(self, input_ids, attention_mask):
            out = self.base(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = out.hidden_states[-1]
            # Mean-pool over real tokens only; a plain hidden.mean(dim=1) would also
            # average the padding positions introduced by padding="max_length".
            token_mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            emb = (hidden * token_mask).sum(dim=1) / token_mask.sum(dim=1).clamp_min(1e-6)
            emb = torch.nan_to_num(emb, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(emb)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm
|
|
    model = GemmaSimCSE(model_lora).to(device)
    # Anomaly detection is slow but pinpoints the op that produces NaN/Inf gradients.
    torch.autograd.set_detect_anomaly(True)
|
|
    # Data: one example per line of a plain-text file; drop empty lines and hold out a
    # validation split.
    raw = load_dataset("text", data_files={"train": DATA_FILE}, split="train")
    raw = raw.filter(lambda x: x["text"].strip() != "")
    split = raw.train_test_split(test_size=VAL_RATIO, seed=SEED)
    train_ds = split["train"]
    val_ds = split["test"]
|
|
    def tokenize_fn(batch):
        toks = tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
        )
        return {"input_ids": toks["input_ids"], "attention_mask": toks["attention_mask"]}
|
|
    train_ds = train_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"],
    )
    val_ds = val_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"],
    )
|
|
    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
|
|
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
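
    # The loss below uses in-batch negatives, so BATCH_SIZE directly controls how many
    # negatives each example is contrasted against (2 * BATCH_SIZE - 2 per row).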
|
|
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    total_steps = len(train_loader) * NUM_EPOCHS
    # Linear decay after a 10% warmup.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
|
|
    for epoch in range(1, NUM_EPOCHS + 1):
        # Training
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
|
|
            # Two forward passes of the same batch form a SimCSE-style positive pair.
            # SimCSE relies on dropout to make the two views differ; with
            # LORA_DROPOUT = 0.0 that depends entirely on any dropout in the base model.
            emb1 = model(ids, mask)
            emb2 = model(ids, mask)
            emb = torch.cat([emb1, emb2], dim=0)
            # Cosine similarities (the embeddings are unit-norm), temperature-scaled and
            # clamped for numerical safety.
            sim = (emb @ emb.T) / TEMP
            sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)

            # Mask self-similarity so a row can never select itself as the positive.
            sim.fill_diagonal_(-1e9)
|
|
            # For row i (the first view of example i) the target column is i + B (its
            # second view), and vice versa; every other column is an in-batch negative.
            B = emb1.size(0)
            labels = torch.cat([
                torch.arange(B, device=device) + B,
                torch.arange(B, device=device),
            ], dim=0)
|
|
            loss = F.cross_entropy(sim, labels)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()
|
|
            train_loss += loss.item()
|
|
        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch} training complete. avg train loss: {avg_train_loss:.6f}")
|
|
        # Validation: in eval mode dropout is off, so the two views coincide and the loss
        # mainly measures separation from in-batch negatives.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validate Epoch {epoch}", unit="batch"):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)
|
|
                emb1 = model(ids, mask)
                emb2 = model(ids, mask)
                emb = torch.cat([emb1, emb2], dim=0)
                sim = (emb @ emb.T) / TEMP
                sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
                sim.fill_diagonal_(-1e9)
|
|
                B = emb1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device),
                ], dim=0)
|
|
                loss = F.cross_entropy(sim, labels)
                val_loss += loss.item()
|
|
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} validation complete. avg val loss: {avg_val_loss:.6f}")
|
|
        # Save a per-epoch checkpoint of the LoRA adapter weights and the tokenizer.
        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        model_lora.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
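
        # The projection head lives outside the PEFT model, so it is not part of the
        # adapter checkpoint. A minimal sketch for keeping it alongside the adapters
        # (the file name "proj_head.pt" is an arbitrary choice):
        torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "proj_head.pt"))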
|
|
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model_lora.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("Training and validation complete. Final model saved to", final_dir)
|
|
if __name__ == "__main__":
    # The __main__ guard (plus freeze_support) keeps the multiprocessing used by
    # datasets.map(num_proc=4) safe on Windows; freeze_support is a no-op elsewhere.
    freeze_support()
    main()
|
|