| """ |
| Q-GPT Training Script |
| Train the quantum head on GPT outputs. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
| import json |
| import os |
|
|
| from quantum_head import QuantumHead, load_qgpt |
|
|
|
|
class ConfidenceDataset(Dataset):
    """Dataset for training the quantum confidence head.

    Reads a JSONL file (one JSON object per line). Each record must contain
    a ``"text"`` field; ``"confidence"`` (defaults to 0.5) and
    ``"is_correct"`` (defaults to True) are optional.

    Args:
        data_path: Path to the JSONL file.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=...,
            max_length=..., padding=..., return_tensors="pt")`` that returns
            a mapping with ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_length: Length every example is truncated/padded to.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Skip blank lines (e.g. a trailing newline at end of file)
                # instead of letting json.loads fail on them.
                if not line:
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized text and float labels for record ``idx``."""
        item = self.data[idx]

        encoding = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        return {
            # Drop only the leading batch axis; a bare squeeze() would also
            # collapse the sequence axis when max_length == 1.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            # Force float32 so an integer "confidence" in the JSON does not
            # produce an int64 label tensor.
            "confidence_label": torch.tensor(
                float(item.get("confidence", 0.5)), dtype=torch.float32
            ),
            "is_correct": torch.tensor(
                float(item.get("is_correct", True)), dtype=torch.float32
            ),
        }
|
|
|
|
def train_quantum_head(
    model_name: str = "squ11z1/gpt-oss-9b-reasoning",
    train_data_path: str = None,
    output_dir: str = "./q_gpt_trained",
    epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    device: str = "cuda",
):
    """
    Train the quantum head on confidence estimation.

    The base language model is loaded frozen and only used to produce the
    last-layer hidden states; gradients flow through the quantum head only.
    The head's state dict is always saved to ``<output_dir>/quantum_head.pt``,
    even when no training happened (missing, absent or empty data).

    Args:
        model_name: Base model name
        train_data_path: Path to training data (jsonl with text, confidence, is_correct)
        output_dir: Where to save trained weights
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate for quantum head
        device: Device to train on

    Returns:
        The (possibly untrained) quantum head module.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model: {model_name}")

    # Frozen feature extractor: eval mode, no gradients.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    base_model.eval()
    for param in base_model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # padding="max_length" in the dataset needs a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    hidden_size = base_model.config.hidden_size
    quantum_head = QuantumHead(hidden_size=hidden_size).to(device)

    optimizer = torch.optim.AdamW(quantum_head.parameters(), lr=learning_rate)

    # The base model runs in bfloat16; feed the head tensors in the dtype of
    # its own parameters so the two do not disagree.
    head_dtype = next(quantum_head.parameters()).dtype

    confidence_loss_fn = nn.BCELoss()
    correctness_loss_fn = nn.BCELoss()

    dataset = None
    if train_data_path and os.path.exists(train_data_path):
        dataset = ConfidenceDataset(train_data_path, tokenizer)

    if dataset is not None and len(dataset) > 0:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        # Non-empty dataset guarantees at least one batch, so the per-epoch
        # average below cannot divide by zero.
        num_batches = len(dataloader)

        for epoch in range(epochs):
            quantum_head.train()
            total_loss = 0.0

            for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                confidence_labels = batch["confidence_label"].to(device)
                correctness_labels = batch["is_correct"].to(device)

                # Features come from the frozen model; no graph is needed.
                with torch.no_grad():
                    outputs = base_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=True,
                    )
                    hidden_states = outputs.hidden_states[-1]

                qout = quantum_head(
                    hidden_states.to(device=device, dtype=head_dtype)
                )

                conf_loss = confidence_loss_fn(qout["confidence"], confidence_labels)

                # NOTE(review): the correctness loss is also computed on
                # qout["confidence"]; confirm whether QuantumHead exposes a
                # separate correctness output that should be used here.
                correct_loss = correctness_loss_fn(qout["confidence"], correctness_labels)

                loss = 0.5 * conf_loss + 0.5 * correct_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / num_batches
            print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")
    elif dataset is not None:
        print(f"Training data at {train_data_path} is empty. Saving untrained quantum head.")
    elif train_data_path:
        # A path was given but does not exist: say so rather than claiming
        # that no data was provided.
        print(f"Training data not found at {train_data_path}. Saving untrained quantum head.")
    else:
        print("No training data provided. Saving untrained quantum head.")

    save_path = os.path.join(output_dir, "quantum_head.pt")
    torch.save(quantum_head.state_dict(), save_path)
    print(f"Saved quantum head to {save_path}")

    return quantum_head
|
|
|
|
def create_synthetic_training_data(
    model_name: str,
    output_path: str,
    num_samples: int = 1000,
):
    """Create synthetic training data from model generations.

    For each sample a prompt is drawn at random from a fixed list, the model
    generates up to 50 new tokens, and the decoded text is stored with a
    randomly drawn confidence label. Prompts containing one of a few
    "factual" keywords get a confidence in [0.7, 0.95], all others in
    [0.4, 0.7]; ``is_correct`` is simply ``confidence > 0.5``. The labels
    are therefore heuristic, not measured.

    Args:
        model_name: Model used to generate the texts.
        output_path: JSONL file to write (one record per line with the keys
            ``text``, ``confidence`` and ``is_correct``).
        num_samples: Number of records to generate.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import random

    # Create the target directory before the expensive generation loop so a
    # missing directory cannot discard all generated samples at write time.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("Creating synthetic training data...")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    prompts = [
        "What is 2 + 2?",
        "Explain quantum mechanics.",
        "Who was the first president of USA?",
        "Solve: x^2 - 4 = 0",
        "What is the capital of France?",
        "Explain machine learning.",
        "What is consciousness?",
        "Calculate 15% of 200.",
    ]
    factual_keywords = ("what is", "who", "calculate", "solve")

    data = []

    for _ in tqdm(range(num_samples)):
        prompt = random.choice(prompts)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,
                temperature=0.7,
            )

        text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keyword heuristic: "factual" prompts get a higher confidence band.
        is_factual = any(kw in prompt.lower() for kw in factual_keywords)
        confidence = random.uniform(0.7, 0.95) if is_factual else random.uniform(0.4, 0.7)

        data.append({
            "text": text,
            "confidence": confidence,
            "is_correct": confidence > 0.5,
        })

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in data)

    print(f"Created {len(data)} samples at {output_path}")
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: either generate synthetic data
    # (--create-data) or train the quantum head.
    parser = argparse.ArgumentParser(
        description="Train the quantum head on GPT outputs."
    )
    parser.add_argument("--model", default="squ11z1/gpt-oss-9b-reasoning")
    parser.add_argument("--data", default=None)
    parser.add_argument("--output", default="./q_gpt_trained")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--create-data", action="store_true")
    # The options below expose parameters the functions already accept; the
    # defaults equal the function defaults, so existing invocations behave
    # exactly as before.
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--num-samples", type=int, default=1000)

    args = parser.parse_args()

    if args.create_data:
        create_synthetic_training_data(
            args.model,
            args.data or "train_data.jsonl",
            num_samples=args.num_samples,
        )
    else:
        train_quantum_head(
            model_name=args.model,
            train_data_path=args.data,
            output_dir=args.output,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=args.device,
        )
|
|