"""
PhD Research OS - ZeroGPU Training Space
==========================================
Trains the Research OS brain on ZeroGPU (H200) in micro-batches.
Each @spaces.GPU call trains for ~55 seconds, saves a checkpoint, and resumes from it on the next call.
Usage: Deploy as HF Space with ZeroGPU hardware.
"""
import os
import json
import time
import torch
import spaces
import gradio as gr
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model
from trl import SFTConfig, SFTTrainer
# ============================================================
# Configuration
# ============================================================
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
DATASET_NAME = "nkshirsa/phd-research-os-sft-data"
OUTPUT_DIR = "./checkpoints"
HUB_MODEL_ID = "nkshirsa/phd-research-os-brain"
MAX_TRAIN_SECONDS = 55  # ~5s buffer under the 60s ZeroGPU limit (informational; @spaces.GPU(duration=...) enforces the actual cap)
os.makedirs(OUTPUT_DIR, exist_ok=True)
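# Checkpoints written here persist across GPU calls while the Space is running,
# but Space disk is ephemeral unless persistent storage is enabled, so a restart
# may wipe them; push to the Hub to keep progress.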
# ============================================================
# Global state (loaded at module level per ZeroGPU docs)
# ============================================================
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("Loading dataset...")
dataset = load_dataset(DATASET_NAME)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
print(f"Dataset loaded: {len(train_dataset)} train, {len(eval_dataset)} eval")
# Track training state
training_log = []
total_steps_completed = 0
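# Caveat: these counters live in the Gradio process's memory. If ZeroGPU executes
# @spaces.GPU functions in a separate worker process, updates made inside them may
# not be reflected here; the on-disk adapter checkpoint remains the source of truth.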
# ============================================================
# Training function (runs on GPU)
# ============================================================
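# duration=60 asks ZeroGPU for up to 60 seconds of H200 time per call; keep
# steps_to_train small enough that one micro-batch fits inside that window
# (MAX_TRAIN_SECONDS above is the intended ~55s budget).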
@spaces.GPU(duration=60)
def train_micro_batch(steps_to_train: int = 20, learning_rate: float = 2e-4,
lora_r: int = 32) -> str:
"""
Train for a small number of steps on ZeroGPU.
Each call gets ~60 seconds of H200 GPU time.
"""
    global total_steps_completed, training_log
    # Gradio sliders may deliver floats; cast to the integers expected by max_steps and the LoRA rank
    steps_to_train = int(steps_to_train)
    lora_r = int(lora_r)
    start_time = time.time()
try:
# Load model with 4-bit quantization
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
# Check for existing checkpoint
checkpoint_path = None
if os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
checkpoint_path = OUTPUT_DIR
log_msg = f"Resuming from checkpoint at step {total_steps_completed}"
else:
log_msg = "Starting fresh training"
print(log_msg)
# LoRA config
peft_config = LoraConfig(
r=lora_r,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
)
        # Training config - micro batch
training_args = SFTConfig(
output_dir=OUTPUT_DIR,
max_steps=steps_to_train,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
learning_rate=learning_rate,
lr_scheduler_type="cosine",
warmup_steps=min(5, steps_to_train // 4),
weight_decay=0.01,
bf16=True,
gradient_checkpointing=True,
max_length=1024,
logging_steps=5,
logging_first_step=True,
save_steps=steps_to_train, # Save at end of micro-batch
save_total_limit=2,
disable_tqdm=True,
report_to=[],
seed=42,
            # Don't push every micro-batch - we push manually at the end
push_to_hub=False,
)
# Initialize trainer
if checkpoint_path:
# Resume: load base model + existing adapter
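            # Note: only the LoRA adapter weights are restored here; optimizer and LR
            # scheduler state start fresh, so each call is an independent short run
            # that continues from the previously saved adapter.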
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
model = PeftModel.from_pretrained(model, checkpoint_path, is_trainable=True)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
processing_class=tokenizer,
)
else:
# Fresh start
training_args.model_init_kwargs = {
"quantization_config": bnb_config,
"torch_dtype": torch.bfloat16,
}
trainer = SFTTrainer(
model=MODEL_NAME,
args=training_args,
train_dataset=train_dataset,
peft_config=peft_config,
processing_class=tokenizer,
)
# Train
result = trainer.train()
# Save checkpoint
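        # With a PEFT-wrapped model this writes just the adapter files
        # (adapter_config.json + adapter weights), which is exactly what the
        # resume check at the top of this function looks for.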
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
elapsed = time.time() - start_time
total_steps_completed += steps_to_train
# Log results
metrics = {
"steps_this_batch": steps_to_train,
"total_steps": total_steps_completed,
"train_loss": result.metrics.get("train_loss", "N/A"),
"elapsed_seconds": round(elapsed, 1),
"learning_rate": learning_rate,
"lora_r": lora_r,
}
training_log.append(metrics)
summary = f"""βœ… **Micro-batch complete!**
| Metric | Value |
|--------|-------|
| Steps trained | {steps_to_train} |
| Total steps | {total_steps_completed} |
| Training loss | {result.metrics.get('train_loss', 'N/A')} |
| Time | {elapsed:.1f}s |
| Checkpoint | `{OUTPUT_DIR}` |
*Call again to continue training. Each call adds more steps.*
"""
return summary
except Exception as e:
elapsed = time.time() - start_time
error_msg = f"❌ Training error after {elapsed:.1f}s: {str(e)}"
training_log.append({"error": str(e), "elapsed": elapsed})
return error_msg
@spaces.GPU(duration=60)
def evaluate_model() -> str:
"""Run evaluation on the test set."""
if not os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
return "❌ No checkpoint found. Train first."
try:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
model = PeftModel.from_pretrained(model, OUTPUT_DIR)
training_args = SFTConfig(
output_dir="./eval_tmp",
per_device_eval_batch_size=1,
bf16=True,
disable_tqdm=True,
report_to=[],
)
trainer = SFTTrainer(
model=model,
args=training_args,
eval_dataset=eval_dataset,
processing_class=tokenizer,
)
metrics = trainer.evaluate()
summary = f"""βœ… **Evaluation complete!**
| Metric | Value |
|--------|-------|
| Eval Loss | {metrics.get('eval_loss', float('nan')):.4f} |
| Eval Samples | {metrics.get('eval_samples', len(eval_dataset))} |
| Total Train Steps | {total_steps_completed} |
"""
return summary
except Exception as e:
return f"❌ Evaluation error: {str(e)}"
@spaces.GPU(duration=120)
def push_to_hub() -> str:
"""Push the trained adapter to HF Hub."""
if not os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
return "❌ No checkpoint found. Train first."
try:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
model = PeftModel.from_pretrained(model, OUTPUT_DIR)
model.push_to_hub(HUB_MODEL_ID, commit_message=f"ZeroGPU training: {total_steps_completed} steps")
tokenizer.push_to_hub(HUB_MODEL_ID)
return f"""βœ… **Model pushed to Hub!**
πŸ”— [https://huggingface.co/{HUB_MODEL_ID}](https://huggingface.co/{HUB_MODEL_ID})
Total steps trained: {total_steps_completed}
"""
except Exception as e:
return f"❌ Push error: {str(e)}"
def get_training_log():
"""Show training history."""
if not training_log:
return "No training runs yet. Click 'Train' to start."
    lines = ["| Run | Total Steps | Loss | Time |", "|-----|-------------|------|------|"]
for i, entry in enumerate(training_log):
if "error" in entry:
            lines.append(f"| {i+1} | ERROR | - | {entry.get('elapsed', '?')}s |")
else:
lines.append(f"| {i+1} | {entry.get('total_steps', '?')} | {entry.get('train_loss', '?')} | {entry.get('elapsed_seconds', '?')}s |")
return "\n".join(lines)
@spaces.GPU(duration=60)
def test_inference(prompt: str) -> str:
"""Test the trained model with a prompt."""
if not os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
return "❌ No checkpoint found. Train first."
try:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
device_map="auto",
)
model = PeftModel.from_pretrained(model, OUTPUT_DIR)
model.eval()
messages = [
{"role": "system", "content": "You are a scientific claim extractor. Extract claims as JSON."},
{"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=512,
temperature=0.1,
do_sample=True,
top_p=0.95,
)
response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response
except Exception as e:
return f"❌ Inference error: {str(e)}"
# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="PhD Research OS - Training") as app:
gr.Markdown(f"""
    # 🧠 PhD Research OS - Model Training (ZeroGPU)
**Base Model**: `{MODEL_NAME}`
**Dataset**: `{DATASET_NAME}` ({len(train_dataset)} train / {len(eval_dataset)} eval)
**Method**: QLoRA (4-bit NF4) on ZeroGPU H200
Each "Train" click runs ~20 gradient steps in ~55 seconds of GPU time.
Click multiple times to accumulate training. Push to Hub when satisfied.
""")
with gr.Tabs():
        with gr.Tab("🏋️ Train"):
with gr.Row():
steps_input = gr.Slider(5, 50, value=20, step=5, label="Steps per micro-batch")
lr_input = gr.Slider(1e-5, 5e-4, value=2e-4, step=1e-5, label="Learning Rate")
rank_input = gr.Slider(8, 64, value=32, step=8, label="LoRA Rank")
            train_btn = gr.Button("🏋️ Train Micro-Batch (uses ~60s GPU)", variant="primary", size="lg")
train_output = gr.Markdown()
train_btn.click(train_micro_batch, inputs=[steps_input, lr_input, rank_input], outputs=train_output)
gr.Markdown("---")
            log_btn = gr.Button("📋 Show Training Log")
log_output = gr.Markdown()
log_btn.click(get_training_log, outputs=log_output)
        with gr.Tab("📊 Evaluate"):
            eval_btn = gr.Button("📊 Run Evaluation", variant="primary")
eval_output = gr.Markdown()
eval_btn.click(evaluate_model, outputs=eval_output)
        with gr.Tab("🧪 Test"):
test_prompt = gr.Textbox(
label="Test Prompt",
value="Extract claims from: The LOD was 0.8 fM in 10 mM PBS (n=5, p<0.001). Sensitivity may decrease at physiological ionic strength.",
lines=3,
)
            test_btn = gr.Button("🧪 Run Inference", variant="primary")
test_output = gr.Textbox(label="Model Output", lines=10)
test_btn.click(test_inference, inputs=test_prompt, outputs=test_output)
        with gr.Tab("🚀 Push to Hub"):
gr.Markdown(f"Push the trained LoRA adapter to [{HUB_MODEL_ID}](https://huggingface.co/{HUB_MODEL_ID})")
            push_btn = gr.Button("🚀 Push to Hub", variant="primary")
push_output = gr.Markdown()
push_btn.click(push_to_hub, outputs=push_output)
if __name__ == "__main__":
app.launch()