# Ashira Pitchayapakayakul
# feat: clone working ashirato ZeroGPU app.py — 2nd PRO endpoint
# 87b89b2
"""Surrogate-1 ZeroGPU — chat + synth-batch endpoints (gr.Blocks 4.44).
Two functions exposed via Gradio API:
• POST /run/respond — single chat completion (also UI tab)
• POST /run/synth_batch — Magpie-style synthetic training pair batch
synth_batch is hit by ~/.surrogate/bin/v2/synth-puller.sh every 5 min
on the bulk Space, drains free PRO ZeroGPU budget into training data.
Each call returns up to 20 JSONL pairs as a single string.
Earlier ChatInterface attempts hit a starlette TemplateResponse failure
during gradio's static-route init. gr.Blocks with explicit api_name on
each click avoids the same code path and exposes both endpoints cleanly.
Backbone: Qwen2.5-Coder-7B-Instruct + Surrogate-1 v1 LoRA, bnb int4.
"""
import json
import os
import gradio as gr
import spaces
import torch
# Model configuration — each value overridable via Space secrets / env vars.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
LORA_REPO = os.environ.get("LORA_REPO", "axentx/surrogate-1-coder-7b-lora-v1")
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # empty -> anonymous Hub access

# System prompt prepended to every chat completion (see _generate).
SYSTEM = ("You are Surrogate-1, expert DevSecOps + SRE + coding agent. "
          "Cite real APIs only. Say IDK rather than confabulate.")

# domain key -> hint text interpolated into the synth_batch seed prompt.
# Unknown domains fall back to the raw domain string (DOMAIN_HINTS.get).
DOMAIN_HINTS = {
    "code-python": "Python coding tasks, idiomatic, type-hinted",
    "code-typescript": "TypeScript / React / Node tasks, strict types",
    "code-rust": "Rust ownership, async, performance",
    "code-go": "Go concurrency, stdlib, microservices",
    "devops-tf": "Terraform AWS/GCP modules, best practices",
    "devops-k8s": "Kubernetes manifests, helm, troubleshooting",
    "devops-cdk": "AWS CDK constructs, TypeScript",
    "ci-github": "GitHub Actions workflows, reusable, secure",
    "sec-iam": "IAM least-privilege policies, AssumeRole",
    "sec-cve": "CVE remediation, SCA, dependency hygiene",
    "sre-runbook": "Incident runbooks, on-call, postmortems",
    "sre-slo": "SLO/SLI/error budgets, observability",
    "data-sql": "SQL queries, indexes, query plans, optimisation",
    "ai-eng": "RAG, vLLM, fine-tuning, evals",
    "api-rest": "REST API design, OpenAPI, idempotency",
    "test-pytest": "pytest fixtures, parametrize, markers",
}

# Lazily-initialized singletons, populated by _load() on first GPU call.
_model = None
_tokenizer = None
def _load():
    """Build and cache the tokenizer and int4-quantized model (+ optional LoRA).

    Idempotent: every call after the first returns the cached pair.

    Returns:
        tuple: (model, tokenizer) ready for generation on CUDA.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    # Heavy imports are deferred so plain module import stays cheap.
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              BitsAndBytesConfig)

    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True)
    # Some checkpoints ship without a pad token; reuse EOS so batching works.
    if _tokenizer.pad_token_id is None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    _model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True,
        device_map="cuda", quantization_config=quant_cfg)

    if LORA_REPO:
        # Best-effort adapter load: a missing/broken repo degrades to base.
        try:
            from peft import PeftModel
            _model = PeftModel.from_pretrained(
                _model, LORA_REPO, token=HF_TOKEN or None)
            print(f"[ok] LoRA: {LORA_REPO}")
        except Exception as e:
            print(f"[skip] LoRA: {e}")

    return _model, _tokenizer
def _generate(prompt: str, max_tokens: int = 768,
              temperature: float = 0.7) -> str:
    """Run one system+user chat completion and return the decoded reply.

    Args:
        prompt: User-turn text; SYSTEM is always prepended.
        max_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature (do_sample is always on).

    Returns:
        The model's reply text, stripped, without the echoed prompt.
    """
    model, tokenizer = _load()
    conversation = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True)
    # Truncate very long prompts so we stay inside the context window.
    batch = tokenizer(rendered, return_tensors="pt", truncation=True,
                      max_length=8000).to("cuda")
    generated = model.generate(
        **batch,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt tokens; decode only the newly generated tail.
    prompt_len = batch["input_ids"].shape[1]
    return tokenizer.decode(generated[0][prompt_len:],
                            skip_special_tokens=True).strip()
@spaces.GPU(duration=300)
def respond(message: str) -> str:
    """Single chat completion (UI "chat" tab and the /run/respond API)."""
    # Guard clause: blank or whitespace-only input short-circuits the GPU call.
    if not (message or "").strip():
        return "(empty)"
    return _generate(message, max_tokens=768, temperature=0.4)
@spaces.GPU(duration=600)
def synth_batch(domain: str, count) -> str:
    """Magpie-style synthetic pair generation. Returns N JSONL lines.

    The model first invents an instruction for the given domain, then
    answers it; pairs failing basic length checks are dropped.

    Args:
        domain: DOMAIN_HINTS key (unknown keys are used verbatim as the hint).
        count: Requested pair count; coerced to int and clamped to 1-20.

    Returns:
        JSONL string, one {"prompt", "response", "source", "meta"} per line.
    """
    domain = (domain or "code-python").strip()
    try:
        requested = int(count or 12)
    except (TypeError, ValueError):
        requested = 12
    requested = max(1, min(20, requested))

    hint = DOMAIN_HINTS.get(domain, domain)
    seed = (f"Generate ONE realistic technical question a senior engineer "
            f"would ask about {hint}. Output JUST the question text, no "
            f"preamble or quotes. Make it specific and answerable in "
            f"200-500 words with code/config examples.")

    jsonl_lines = []
    for _ in range(requested):
        try:
            question = _generate(seed, max_tokens=200, temperature=0.95)
            # Keep only the first line, drop wrapping quotes, cap length.
            question = (question.split("\n")[0]
                        .strip().strip('"').strip("'")[:600])
            if len(question) < 30:
                continue  # degenerate/truncated instruction — skip it
            answer = _generate(question, max_tokens=900, temperature=0.4)
            if len(answer) < 80:
                continue  # too short to be a useful training pair
            record = {
                "prompt": question,
                "response": answer,
                "source": f"surrogate-1-zero-gpu/synth-{domain}",
                "meta": {"domain": domain, "magpie": True},
            }
            jsonl_lines.append(json.dumps(record, ensure_ascii=False))
        except Exception as e:
            # One bad generation must not kill the whole batch.
            print(f"[synth_batch] err: {e}")
            continue
    return "\n".join(jsonl_lines)
# UI + API wiring. gr.Blocks (not ChatInterface) so each button click can
# carry an explicit api_name, exposing /run/respond and /run/synth_batch
# without hitting the ChatInterface static-route failure noted in the
# module docstring.
with gr.Blocks(title="Surrogate-1 ZeroGPU") as demo:
    gr.Markdown("# Surrogate-1 (7B + v1 LoRA, ZeroGPU A10G)")
    gr.Markdown(
        "Qwen2.5-Coder-7B + Surrogate-1 v1 LoRA on free PRO ZeroGPU. "
        "Two API endpoints: `/run/respond` (chat) and `/run/synth_batch` "
        "(synthetic training pair batch — used by synth-puller cron).")

    # Tab 1: interactive chat; also the /run/respond API endpoint.
    with gr.Tab("chat"):
        chat_in = gr.Textbox(
            lines=4,
            placeholder="ask Surrogate-1: code, devops, security…")
        chat_out = gr.Textbox(lines=20, label="response")
        gr.Button("send", variant="primary").click(
            respond, chat_in, chat_out, api_name="respond")
        gr.Examples(
            [["Write a Terraform module for AWS S3 with KMS encryption "
              "+ versioning."],
             ["Implement Redis-based rate limit per-API-key in FastAPI."],
             ["Diagnose: Lambda cold-start 3s on 256MB. "
              "Architecture options?"]],
            inputs=chat_in)

    # Tab 2: synthetic training-pair generator; /run/synth_batch endpoint.
    with gr.Tab("synth_batch"):
        gr.Markdown(
            "Magpie-style: model generates instructions per domain, then "
            "responds. Output is JSONL (one pair per line). Domains: "
            + ", ".join(sorted(DOMAIN_HINTS.keys())))
        synth_dom = gr.Textbox(value="code-python", label="domain")
        synth_cnt = gr.Number(value=12, precision=0, label="count (1-20)")
        synth_out = gr.Textbox(lines=20, label="JSONL pairs")
        gr.Button("generate", variant="primary").click(
            synth_batch, [synth_dom, synth_cnt], synth_out,
            api_name="synth_batch")

# Small queue: one model instance, so excess clients wait instead of OOMing.
demo.queue(max_size=8).launch()