| """Surrogate-1 ZeroGPU — chat + synth-batch endpoints (gr.Blocks 4.44). | |
| Two functions exposed via Gradio API: | |
| • POST /run/respond — single chat completion (also UI tab) | |
| • POST /run/synth_batch — Magpie-style synthetic training pair batch | |
| synth_batch is hit by ~/.surrogate/bin/v2/synth-puller.sh every 5 min | |
| on the bulk Space, drains free PRO ZeroGPU budget into training data. | |
| Each call returns up to 20 JSONL pairs as a single string. | |
| Earlier ChatInterface attempts hit a starlette TemplateResponse failure | |
| during gradio's static-route init. gr.Blocks with explicit api_name on | |
| each click avoids the same code path and exposes both endpoints cleanly. | |
| Backbone: Qwen2.5-Coder-7B-Instruct + Surrogate-1 v1 LoRA, bnb int4. | |
| """ | |
import json
import os

import gradio as gr
import spaces
import torch

BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
LORA_REPO = os.environ.get("LORA_REPO", "axentx/surrogate-1-coder-7b-lora-v1")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

SYSTEM = ("You are Surrogate-1, expert DevSecOps + SRE + coding agent. "
          "Cite real APIs only. Say IDK rather than confabulate.")

DOMAIN_HINTS = {
    "code-python": "Python coding tasks, idiomatic, type-hinted",
    "code-typescript": "TypeScript / React / Node tasks, strict types",
    "code-rust": "Rust ownership, async, performance",
    "code-go": "Go concurrency, stdlib, microservices",
    "devops-tf": "Terraform AWS/GCP modules, best practices",
    "devops-k8s": "Kubernetes manifests, helm, troubleshooting",
    "devops-cdk": "AWS CDK constructs, TypeScript",
    "ci-github": "GitHub Actions workflows, reusable, secure",
    "sec-iam": "IAM least-privilege policies, AssumeRole",
    "sec-cve": "CVE remediation, SCA, dependency hygiene",
    "sre-runbook": "Incident runbooks, on-call, postmortems",
    "sre-slo": "SLO/SLI/error budgets, observability",
    "data-sql": "SQL queries, indexes, query plans, optimisation",
    "ai-eng": "RAG, vLLM, fine-tuning, evals",
    "api-rest": "REST API design, OpenAPI, idempotency",
    "test-pytest": "pytest fixtures, parametrize, markers",
}
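
# Lazy singletons: the quantized base model (+ optional LoRA) loads once, on
# the first request, and is reused by every subsequent call.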
_model = None
_tokenizer = None


def _load():
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              BitsAndBytesConfig)
    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True)
    if _tokenizer.pad_token_id is None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id
    bnb = BitsAndBytesConfig(load_in_4bit=True,
                             bnb_4bit_compute_dtype=torch.bfloat16,
                             bnb_4bit_quant_type="nf4",
                             bnb_4bit_use_double_quant=True)
    _model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True,
        device_map="cuda", quantization_config=bnb)
    if LORA_REPO:
        try:
            from peft import PeftModel
            _model = PeftModel.from_pretrained(
                _model, LORA_REPO, token=HF_TOKEN or None)
            print(f"[ok] LoRA: {LORA_REPO}")
        except Exception as e:
            print(f"[skip] LoRA: {e}")
    return _model, _tokenizer


@spaces.GPU  # ZeroGPU attaches a GPU only inside @spaces.GPU-decorated calls
def _generate(prompt: str, max_tokens: int = 768,
              temperature: float = 0.7) -> str:
    model, tokenizer = _load()
    msgs = [{"role": "system", "content": SYSTEM},
            {"role": "user", "content": prompt}]
    chat = tokenizer.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(chat, return_tensors="pt", truncation=True,
                       max_length=8000).to("cuda")
    out = model.generate(
        **inputs,
        max_new_tokens=max_tokens, temperature=temperature, do_sample=True,
        top_p=0.9, pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(
        out[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True).strip()


def respond(message: str) -> str:
    if not message or not message.strip():
        return "(empty)"
    return _generate(message, max_tokens=768, temperature=0.4)


def synth_batch(domain: str, count) -> str:
    """Magpie-style synthetic pair generation. Returns N JSONL lines."""
    domain = (domain or "code-python").strip()
    try:
        count = int(count or 12)
    except (TypeError, ValueError):
        count = 12
    count = max(1, min(20, count))
    hint = DOMAIN_HINTS.get(domain, domain)
    seed = (f"Generate ONE realistic technical question a senior engineer "
            f"would ask about {hint}. Output JUST the question text, no "
            f"preamble or quotes. Make it specific and answerable in "
            f"200-500 words with code/config examples.")
    pairs = []
    for _ in range(count):
        try:
            instruction = _generate(seed, max_tokens=200, temperature=0.95)
            instruction = (instruction.split("\n")[0]
                           .strip().strip('"').strip("'")[:600])
            if len(instruction) < 30:
                continue
            response = _generate(instruction, max_tokens=900,
                                 temperature=0.4)
            if len(response) < 80:
                continue
            pairs.append(json.dumps({
                "prompt": instruction,
                "response": response,
                "source": f"surrogate-1-zero-gpu/synth-{domain}",
                "meta": {"domain": domain, "magpie": True},
            }, ensure_ascii=False))
        except Exception as e:
            print(f"[synth_batch] err: {e}")
            continue
    return "\n".join(pairs)


with gr.Blocks(title="Surrogate-1 ZeroGPU") as demo:
    gr.Markdown("# Surrogate-1 (7B + v1 LoRA, ZeroGPU A10G)")
    gr.Markdown(
        "Qwen2.5-Coder-7B + Surrogate-1 v1 LoRA on free PRO ZeroGPU. "
        "Two API endpoints: `/run/respond` (chat) and `/run/synth_batch` "
        "(synthetic training pair batch — used by synth-puller cron).")
    with gr.Tab("chat"):
        chat_in = gr.Textbox(
            lines=4,
            placeholder="ask Surrogate-1: code, devops, security…")
        chat_out = gr.Textbox(lines=20, label="response")
        gr.Button("send", variant="primary").click(
            respond, chat_in, chat_out, api_name="respond")
        gr.Examples(
            [["Write a Terraform module for AWS S3 with KMS encryption "
              "+ versioning."],
             ["Implement Redis-based rate limit per-API-key in FastAPI."],
             ["Diagnose: Lambda cold-start 3s on 256MB. "
              "Architecture options?"]],
            inputs=chat_in)
    with gr.Tab("synth_batch"):
        gr.Markdown(
            "Magpie-style: model generates instructions per domain, then "
            "responds. Output is JSONL (one pair per line). Domains: "
            + ", ".join(sorted(DOMAIN_HINTS.keys())))
        synth_dom = gr.Textbox(value="code-python", label="domain")
        synth_cnt = gr.Number(value=12, precision=0, label="count (1-20)")
        synth_out = gr.Textbox(lines=20, label="JSONL pairs")
        gr.Button("generate", variant="primary").click(
            synth_batch, [synth_dom, synth_cnt], synth_out,
            api_name="synth_batch")

demo.queue(max_size=8).launch()
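
# Local smoke test (a sketch, assuming this file is saved as app.py and a
# CUDA GPU is available; spaces.GPU degrades to a no-op outside Spaces):
#   HF_TOKEN=hf_... python app.py   # Gradio UI on http://127.0.0.1:7860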