"""
PhD Research OS — ZeroGPU Training Space
========================================

Trains the Research OS brain on ZeroGPU (H200) in micro-batches.
Each @spaces.GPU call trains for ~55 seconds, saves a checkpoint,
and resumes on the next call.

Usage: Deploy as HF Space with ZeroGPU hardware.
"""
import os
import json
import time

import torch
import spaces
import gradio as gr
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model
from trl import SFTConfig, SFTTrainer

# ============================================================
# Configuration
# ============================================================
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
DATASET_NAME = "nkshirsa/phd-research-os-sft-data"
OUTPUT_DIR = "./checkpoints"
HUB_MODEL_ID = "nkshirsa/phd-research-os-brain"
MAX_TRAIN_SECONDS = 55  # Leave 5s buffer from 60s ZeroGPU limit
# NOTE(review): MAX_TRAIN_SECONDS is never consumed anywhere below — the
# time budget is only enforced indirectly via small `steps_to_train`.
# Either wire it into a trainer callback or remove it; confirm intent.

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================================================
# Global state (loaded at module level per ZeroGPU docs)
# ============================================================
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading dataset...")
dataset = load_dataset(DATASET_NAME)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
print(f"Dataset loaded: {len(train_dataset)} train, {len(eval_dataset)} eval")

# Track training state. NOTE(review): this lives in process memory only —
# a Space restart resets the counters even though the adapter checkpoint
# on disk survives, so "Total steps" can under-report after a restart.
training_log = []
total_steps_completed = 0


# ============================================================
# Shared helpers (keep the four GPU entry points DRY)
# ============================================================
def _bnb_config() -> BitsAndBytesConfig:
    """Return the 4-bit NF4 quantization config used by every GPU entry point."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )


def _load_base_model():
    """Load the quantized base model onto the GPU (device_map='auto')."""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=_bnb_config(),
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )


def _has_checkpoint() -> bool:
    """True when a previously saved LoRA adapter exists in OUTPUT_DIR."""
    return os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json"))


# ============================================================
# Training function (runs on GPU)
# ============================================================
@spaces.GPU(duration=60)
def train_micro_batch(steps_to_train: int = 20,
                      learning_rate: float = 2e-4,
                      lora_r: int = 32) -> str:
    """
    Train for a small number of steps on ZeroGPU.

    Each call gets ~60 seconds of H200 GPU time. The LoRA adapter is saved
    to OUTPUT_DIR at the end of every call so the next call can resume.

    Args:
        steps_to_train: Optimizer steps for this micro-batch.
        learning_rate: Peak LR for the cosine schedule.
        lora_r: LoRA rank (only applied on a fresh start; a resumed
            adapter keeps the rank it was created with).

    Returns:
        Markdown summary of the run, or an error message on failure.
    """
    global total_steps_completed, training_log
    start_time = time.time()
    try:
        # Resume from an existing adapter checkpoint if one is on disk.
        checkpoint_path = None
        if _has_checkpoint():
            checkpoint_path = OUTPUT_DIR
            log_msg = f"Resuming from checkpoint at step {total_steps_completed}"
        else:
            log_msg = "Starting fresh training"
        print(log_msg)

        # LoRA config (used only on a fresh start).
        peft_config = LoraConfig(
            r=lora_r,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
        )

        # Training config — micro batch
        training_args = SFTConfig(
            output_dir=OUTPUT_DIR,
            max_steps=steps_to_train,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_steps=min(5, steps_to_train // 4),
            weight_decay=0.01,
            bf16=True,
            gradient_checkpointing=True,
            max_length=1024,
            logging_steps=5,
            logging_first_step=True,
            save_steps=steps_to_train,  # Save at end of micro-batch
            save_total_limit=2,
            disable_tqdm=True,
            report_to=[],
            seed=42,
            # Don't push every micro-batch — we push manually at the end
            push_to_hub=False,
        )

        # Initialize trainer
        if checkpoint_path:
            # Resume: load base model + existing adapter (trainable).
            model = _load_base_model()
            model = PeftModel.from_pretrained(model, checkpoint_path,
                                              is_trainable=True)
            trainer = SFTTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                processing_class=tokenizer,
            )
        else:
            # Fresh start: let SFTTrainer load the base model itself so it
            # can wrap it with a brand-new LoRA adapter from peft_config.
            training_args.model_init_kwargs = {
                "quantization_config": _bnb_config(),
                "torch_dtype": torch.bfloat16,
            }
            trainer = SFTTrainer(
                model=MODEL_NAME,
                args=training_args,
                train_dataset=train_dataset,
                peft_config=peft_config,
                processing_class=tokenizer,
            )

        # Train
        result = trainer.train()

        # Save checkpoint so the next @spaces.GPU call can resume.
        trainer.save_model(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)

        elapsed = time.time() - start_time
        total_steps_completed += steps_to_train

        # Log results
        metrics = {
            "steps_this_batch": steps_to_train,
            "total_steps": total_steps_completed,
            "train_loss": result.metrics.get("train_loss", "N/A"),
            "elapsed_seconds": round(elapsed, 1),
            "learning_rate": learning_rate,
            "lora_r": lora_r,
        }
        training_log.append(metrics)

        summary = f"""✅ **Micro-batch complete!**

| Metric | Value |
|--------|-------|
| Steps trained | {steps_to_train} |
| Total steps | {total_steps_completed} |
| Training loss | {result.metrics.get('train_loss', 'N/A')} |
| Time | {elapsed:.1f}s |
| Checkpoint | `{OUTPUT_DIR}` |

*Call again to continue training. Each call adds more steps.*
"""
        return summary

    except Exception as e:
        elapsed = time.time() - start_time
        error_msg = f"❌ Training error after {elapsed:.1f}s: {str(e)}"
        training_log.append({"error": str(e), "elapsed": elapsed})
        return error_msg


@spaces.GPU(duration=60)
def evaluate_model() -> str:
    """Run evaluation on the test set and return a markdown summary."""
    if not _has_checkpoint():
        return "❌ No checkpoint found. Train first."
    try:
        model = _load_base_model()
        model = PeftModel.from_pretrained(model, OUTPUT_DIR)

        training_args = SFTConfig(
            output_dir="./eval_tmp",
            per_device_eval_batch_size=1,
            bf16=True,
            disable_tqdm=True,
            report_to=[],
        )
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            eval_dataset=eval_dataset,
            processing_class=tokenizer,
        )
        metrics = trainer.evaluate()

        # BUGFIX: the old code applied :.4f directly to
        # metrics.get('eval_loss', 'N/A'), which raises ValueError when the
        # key is missing (can't float-format the string 'N/A').
        eval_loss = metrics.get("eval_loss")
        loss_str = f"{eval_loss:.4f}" if eval_loss is not None else "N/A"

        summary = f"""✅ **Evaluation complete!**

| Metric | Value |
|--------|-------|
| Eval Loss | {loss_str} |
| Eval Samples | {metrics.get('eval_samples', len(eval_dataset))} |
| Total Train Steps | {total_steps_completed} |
"""
        return summary
    except Exception as e:
        return f"❌ Evaluation error: {str(e)}"


@spaces.GPU(duration=120)
def push_to_hub() -> str:
    """Push the trained adapter to HF Hub."""
    if not _has_checkpoint():
        return "❌ No checkpoint found. Train first."
    try:
        model = _load_base_model()
        model = PeftModel.from_pretrained(model, OUTPUT_DIR)
        model.push_to_hub(
            HUB_MODEL_ID,
            commit_message=f"ZeroGPU training: {total_steps_completed} steps",
        )
        tokenizer.push_to_hub(HUB_MODEL_ID)
        return f"""✅ **Model pushed to Hub!**

🔗 [https://huggingface.co/{HUB_MODEL_ID}](https://huggingface.co/{HUB_MODEL_ID})

Total steps trained: {total_steps_completed}
"""
    except Exception as e:
        return f"❌ Push error: {str(e)}"


def get_training_log():
    """Render the in-memory training history as a markdown table."""
    if not training_log:
        return "No training runs yet. Click 'Train' to start."
    lines = ["| Run | Steps | Loss | Time |",
             "|-----|-------|------|------|"]
    for i, entry in enumerate(training_log):
        if "error" in entry:
            lines.append(
                f"| {i+1} | ERROR | — | {entry.get('elapsed', '?')}s |")
        else:
            lines.append(
                f"| {i+1} | {entry.get('total_steps', '?')} | "
                f"{entry.get('train_loss', '?')} | "
                f"{entry.get('elapsed_seconds', '?')}s |")
    return "\n".join(lines)


@spaces.GPU(duration=60)
def test_inference(prompt: str) -> str:
    """Test the trained model with a prompt; returns the decoded completion."""
    if not _has_checkpoint():
        return "❌ No checkpoint found. Train first."
    try:
        model = _load_base_model()
        model = PeftModel.from_pretrained(model, OUTPUT_DIR)
        model.eval()

        messages = [
            {"role": "system",
             "content": "You are a scientific claim extractor. Extract claims as JSON."},
            {"role": "user", "content": prompt},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
            )
        # Slice off the prompt tokens; decode only the new completion.
        response = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"❌ Inference error: {str(e)}"


# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="PhD Research OS — Training") as app:
    gr.Markdown(f"""
# 🧠 PhD Research OS — Model Training (ZeroGPU)

**Base Model**: `{MODEL_NAME}`
**Dataset**: `{DATASET_NAME}` ({len(train_dataset)} train / {len(eval_dataset)} eval)
**Method**: QLoRA (4-bit NF4) on ZeroGPU H200

Each "Train" click runs ~20 gradient steps in ~55 seconds of GPU time.
Click multiple times to accumulate training. Push to Hub when satisfied.
""")

    with gr.Tabs():
        with gr.Tab("🏋️ Train"):
            with gr.Row():
                steps_input = gr.Slider(5, 50, value=20, step=5,
                                        label="Steps per micro-batch")
                lr_input = gr.Slider(1e-5, 5e-4, value=2e-4, step=1e-5,
                                     label="Learning Rate")
                rank_input = gr.Slider(8, 64, value=32, step=8,
                                       label="LoRA Rank")
            train_btn = gr.Button("🏋️ Train Micro-Batch (uses ~60s GPU)",
                                  variant="primary", size="lg")
            train_output = gr.Markdown()
            train_btn.click(train_micro_batch,
                            inputs=[steps_input, lr_input, rank_input],
                            outputs=train_output)
            gr.Markdown("---")
            log_btn = gr.Button("📋 Show Training Log")
            log_output = gr.Markdown()
            log_btn.click(get_training_log, outputs=log_output)

        with gr.Tab("📊 Evaluate"):
            eval_btn = gr.Button("📊 Run Evaluation", variant="primary")
            eval_output = gr.Markdown()
            eval_btn.click(evaluate_model, outputs=eval_output)

        with gr.Tab("🧪 Test"):
            test_prompt = gr.Textbox(
                label="Test Prompt",
                value="Extract claims from: The LOD was 0.8 fM in 10 mM PBS (n=5, p<0.001). Sensitivity may decrease at physiological ionic strength.",
                lines=3,
            )
            test_btn = gr.Button("🧪 Run Inference", variant="primary")
            test_output = gr.Textbox(label="Model Output", lines=10)
            test_btn.click(test_inference, inputs=test_prompt,
                           outputs=test_output)

        with gr.Tab("🚀 Push to Hub"):
            gr.Markdown(f"Push the trained LoRA adapter to [{HUB_MODEL_ID}](https://huggingface.co/{HUB_MODEL_ID})")
            push_btn = gr.Button("🚀 Push to Hub", variant="primary")
            push_output = gr.Markdown()
            push_btn.click(push_to_hub, outputs=push_output)

if __name__ == "__main__":
    app.launch()