| """ | |
| PhD Research OS β ZeroGPU Training Space | |
| ========================================== | |
| Trains the Research OS brain on ZeroGPU (H200) in micro-batches. | |
| Each @spaces.GPU call trains for ~55 seconds, saves checkpoint, resumes next call. | |
| Usage: Deploy as HF Space with ZeroGPU hardware. | |
| """ | |
import os
import time

import torch
import spaces
import gradio as gr
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, PeftModel
from trl import SFTConfig, SFTTrainer
# ============================================================
# Configuration
# ============================================================
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
DATASET_NAME = "nkshirsa/phd-research-os-sft-data"
OUTPUT_DIR = "./checkpoints"
HUB_MODEL_ID = "nkshirsa/phd-research-os-brain"
MAX_TRAIN_SECONDS = 55  # Leave a 5s buffer under the 60s ZeroGPU slot

os.makedirs(OUTPUT_DIR, exist_ok=True)
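# NOTE (assumption): pushing to the Hub needs a write-scoped token. On Spaces, add it
# as an HF_TOKEN secret; huggingface_hub picks it up from the environment automatically.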
# ============================================================
# Global state (loaded at module level per ZeroGPU docs)
# ============================================================
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading dataset...")
dataset = load_dataset(DATASET_NAME)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
print(f"Dataset loaded: {len(train_dataset)} train, {len(eval_dataset)} eval")

# Track training state
training_log = []
total_steps_completed = 0
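# NOTE: Space disk storage is ephemeral. The step counter and the checkpoint in
# OUTPUT_DIR persist across button clicks, but are lost when the Space restarts.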
# ============================================================
# Training function (runs on GPU)
# ============================================================
@spaces.GPU(duration=MAX_TRAIN_SECONDS + 5)  # request a GPU slot sized to the micro-batch budget
def train_micro_batch(steps_to_train: int = 20, learning_rate: float = 2e-4,
                      lora_r: int = 32) -> str:
| """ | |
| Train for a small number of steps on ZeroGPU. | |
| Each call gets ~60 seconds of H200 GPU time. | |
| """ | |
    global total_steps_completed, training_log
    # Gradio sliders deliver floats; the trainer expects ints.
    steps_to_train = int(steps_to_train)
    lora_r = int(lora_r)
    start_time = time.time()
    try:
        # Load model with 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        # Check for an existing checkpoint
        checkpoint_path = None
        if os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
            checkpoint_path = OUTPUT_DIR
            log_msg = f"Resuming from checkpoint at step {total_steps_completed}"
        else:
            log_msg = "Starting fresh training"
        print(log_msg)

        # LoRA config
        peft_config = LoraConfig(
            r=lora_r,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
        )

        # Training config - micro-batch
        training_args = SFTConfig(
            output_dir=OUTPUT_DIR,
            max_steps=steps_to_train,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_steps=min(5, steps_to_train // 4),
            weight_decay=0.01,
            bf16=True,
            gradient_checkpointing=True,
            max_length=1024,
            logging_steps=5,
            logging_first_step=True,
            save_steps=steps_to_train,  # Save at the end of the micro-batch
            save_total_limit=2,
            disable_tqdm=True,
            report_to=[],
            seed=42,
            # Don't push every micro-batch - we push manually at the end
            push_to_hub=False,
        )
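        # Effective batch size is per_device_train_batch_size * gradient_accumulation_steps
        # = 4 sequences per optimizer step, so 20 steps consume ~80 training examples.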
        # Initialize trainer
        if checkpoint_path:
            # Resume: load base model + existing adapter
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                quantization_config=bnb_config,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
            model = PeftModel.from_pretrained(model, checkpoint_path, is_trainable=True)
            trainer = SFTTrainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                processing_class=tokenizer,
            )
        else:
            # Fresh start: let SFTTrainer load the base model and attach a new adapter
            training_args.model_init_kwargs = {
                "quantization_config": bnb_config,
                "torch_dtype": torch.bfloat16,
            }
            trainer = SFTTrainer(
                model=MODEL_NAME,
                args=training_args,
                train_dataset=train_dataset,
                peft_config=peft_config,
                processing_class=tokenizer,
            )
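        # Caveat: only the adapter weights are carried across calls. Optimizer and
        # LR-scheduler state start fresh, so the cosine schedule restarts per micro-batch.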
        # Train
        result = trainer.train()

        # Save checkpoint
        trainer.save_model(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)

        elapsed = time.time() - start_time
        total_steps_completed += steps_to_train

        # Log results
        metrics = {
            "steps_this_batch": steps_to_train,
            "total_steps": total_steps_completed,
            "train_loss": result.metrics.get("train_loss", "N/A"),
            "elapsed_seconds": round(elapsed, 1),
            "learning_rate": learning_rate,
            "lora_r": lora_r,
        }
        training_log.append(metrics)

        summary = f"""✅ **Micro-batch complete!**

| Metric | Value |
|--------|-------|
| Steps trained | {steps_to_train} |
| Total steps | {total_steps_completed} |
| Training loss | {result.metrics.get('train_loss', 'N/A')} |
| Time | {elapsed:.1f}s |
| Checkpoint | `{OUTPUT_DIR}` |

*Call again to continue training. Each call adds more steps.*
"""
        return summary
    except Exception as e:
        elapsed = time.time() - start_time
        error_msg = f"❌ Training error after {elapsed:.1f}s: {str(e)}"
        training_log.append({"error": str(e), "elapsed": elapsed})
        return error_msg
@spaces.GPU(duration=120)  # evaluating the full test split may need more than the 60s default
def evaluate_model() -> str:
    """Run evaluation on the test set."""
    if not os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
        return "⚠️ No checkpoint found. Train first."
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(model, OUTPUT_DIR)

        training_args = SFTConfig(
            output_dir="./eval_tmp",
            per_device_eval_batch_size=1,
            bf16=True,
            disable_tqdm=True,
            report_to=[],
        )
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            eval_dataset=eval_dataset,
            processing_class=tokenizer,
        )
        metrics = trainer.evaluate()
| summary = f"""β **Evaluation complete!** | |
| | Metric | Value | | |
| |--------|-------| | |
| | Eval Loss | {metrics.get('eval_loss', 'N/A'):.4f} | | |
| | Eval Samples | {metrics.get('eval_samples', len(eval_dataset))} | | |
| | Total Train Steps | {total_steps_completed} | | |
| """ | |
| return summary | |
    except Exception as e:
        return f"❌ Evaluation error: {str(e)}"
@spaces.GPU  # the 4-bit load below requires CUDA, so this must also run inside a GPU slot
def push_to_hub() -> str:
    """Push the trained adapter to the HF Hub."""
    if not os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
        return "⚠️ No checkpoint found. Train first."
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(model, OUTPUT_DIR)
        model.push_to_hub(HUB_MODEL_ID, commit_message=f"ZeroGPU training: {total_steps_completed} steps")
        tokenizer.push_to_hub(HUB_MODEL_ID)
        return f"""✅ **Model pushed to Hub!**

🔗 [https://huggingface.co/{HUB_MODEL_ID}](https://huggingface.co/{HUB_MODEL_ID})

Total steps trained: {total_steps_completed}
"""
    except Exception as e:
        return f"❌ Push error: {str(e)}"
def get_training_log():
    """Show training history."""
    if not training_log:
        return "No training runs yet. Click 'Train' to start."
    lines = ["| Run | Steps | Loss | Time |", "|-----|-------|------|------|"]
    for i, entry in enumerate(training_log):
        if "error" in entry:
            lines.append(f"| {i+1} | ERROR | - | {entry.get('elapsed', '?')}s |")
        else:
            lines.append(f"| {i+1} | {entry.get('total_steps', '?')} | {entry.get('train_loss', '?')} | {entry.get('elapsed_seconds', '?')}s |")
    return "\n".join(lines)
@spaces.GPU  # generation runs on the GPU
def test_inference(prompt: str) -> str:
    """Test the trained model with a prompt."""
    if not os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
        return "⚠️ No checkpoint found. Train first."
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(model, OUTPUT_DIR)
        model.eval()

        messages = [
            {"role": "system", "content": "You are a scientific claim extractor. Extract claims as JSON."},
            {"role": "user", "content": prompt},
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
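        # Near-greedy decoding: a low temperature keeps the JSON output stable
        # while do_sample avoids fully deterministic repetition loops.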
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
            )
        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"❌ Inference error: {str(e)}"
# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="PhD Research OS - Training") as app:
    gr.Markdown(f"""
    # 🧠 PhD Research OS - Model Training (ZeroGPU)

    **Base Model**: `{MODEL_NAME}`
    **Dataset**: `{DATASET_NAME}` ({len(train_dataset)} train / {len(eval_dataset)} eval)
    **Method**: QLoRA (4-bit NF4) on ZeroGPU H200

    Each "Train" click runs ~20 gradient steps in ~55 seconds of GPU time.
    Click multiple times to accumulate training. Push to Hub when satisfied.
    """)
    with gr.Tabs():
        with gr.Tab("🏋️ Train"):
            with gr.Row():
                steps_input = gr.Slider(5, 50, value=20, step=5, label="Steps per micro-batch")
                lr_input = gr.Slider(1e-5, 5e-4, value=2e-4, step=1e-5, label="Learning Rate")
                rank_input = gr.Slider(8, 64, value=32, step=8, label="LoRA Rank")
            train_btn = gr.Button("🏋️ Train Micro-Batch (uses ~60s GPU)", variant="primary", size="lg")
            train_output = gr.Markdown()
            train_btn.click(train_micro_batch, inputs=[steps_input, lr_input, rank_input], outputs=train_output)
            gr.Markdown("---")
            log_btn = gr.Button("📋 Show Training Log")
            log_output = gr.Markdown()
            log_btn.click(get_training_log, outputs=log_output)

        with gr.Tab("📊 Evaluate"):
            eval_btn = gr.Button("📊 Run Evaluation", variant="primary")
            eval_output = gr.Markdown()
            eval_btn.click(evaluate_model, outputs=eval_output)

        with gr.Tab("🧪 Test"):
            test_prompt = gr.Textbox(
                label="Test Prompt",
                value="Extract claims from: The LOD was 0.8 fM in 10 mM PBS (n=5, p<0.001). Sensitivity may decrease at physiological ionic strength.",
                lines=3,
            )
            test_btn = gr.Button("🧪 Run Inference", variant="primary")
            test_output = gr.Textbox(label="Model Output", lines=10)
            test_btn.click(test_inference, inputs=test_prompt, outputs=test_output)

        with gr.Tab("🚀 Push to Hub"):
            gr.Markdown(f"Push the trained LoRA adapter to [{HUB_MODEL_ID}](https://huggingface.co/{HUB_MODEL_ID})")
            push_btn = gr.Button("🚀 Push to Hub", variant="primary")
            push_output = gr.Markdown()
            push_btn.click(push_to_hub, outputs=push_output)
if __name__ == "__main__":
    app.launch()
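# Assumed dependencies for this Space (pin versions in requirements.txt as needed):
#   spaces, gradio, torch, transformers, datasets, peft, trl, bitsandbytes, accelerate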