File size: 7,589 Bytes
87b89b2
6d6a848
87b89b2
 
 
6d6a848
87b89b2
 
 
 
 
 
 
 
 
 
 
 
6d6a848
87b89b2
 
 
99cf609
 
87b89b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99cf609
 
 
87b89b2
 
 
 
 
 
 
 
 
99cf609
 
 
 
87b89b2
f091fe3
548096f
87b89b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f091fe3
 
 
87b89b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""Surrogate-1 ZeroGPU — chat + synth-batch endpoints (gr.Blocks 4.44).

Two functions exposed via Gradio API:
  • POST /run/respond      — single chat completion (also UI tab)
  • POST /run/synth_batch  — Magpie-style synthetic training pair batch

synth_batch is hit by ~/.surrogate/bin/v2/synth-puller.sh every 5 min
on the bulk Space, drains free PRO ZeroGPU budget into training data.
Each call returns up to 20 JSONL pairs as a single string.

Earlier ChatInterface attempts hit a starlette TemplateResponse failure
during gradio's static-route init. gr.Blocks with explicit api_name on
each click avoids the same code path and exposes both endpoints cleanly.

Backbone: Qwen2.5-Coder-7B-Instruct + Surrogate-1 v1 LoRA, bnb int4.
"""
import json
import os

import gradio as gr
import spaces
import torch


# Model selection — all overridable via Space secrets / env vars.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
LORA_REPO = os.environ.get("LORA_REPO", "axentx/surrogate-1-coder-7b-lora-v1")
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # empty string -> anonymous HF access

# System prompt prepended to every chat-templated request in _generate().
SYSTEM = ("You are Surrogate-1, expert DevSecOps + SRE + coding agent. "
          "Cite real APIs only. Say IDK rather than confabulate.")

# Domain key -> topic hint interpolated into the synth_batch seed prompt.
# Unknown keys fall through to using the raw key text as the hint.
DOMAIN_HINTS = {
    "code-python":     "Python coding tasks, idiomatic, type-hinted",
    "code-typescript": "TypeScript / React / Node tasks, strict types",
    "code-rust":       "Rust ownership, async, performance",
    "code-go":         "Go concurrency, stdlib, microservices",
    "devops-tf":       "Terraform AWS/GCP modules, best practices",
    "devops-k8s":      "Kubernetes manifests, helm, troubleshooting",
    "devops-cdk":      "AWS CDK constructs, TypeScript",
    "ci-github":       "GitHub Actions workflows, reusable, secure",
    "sec-iam":         "IAM least-privilege policies, AssumeRole",
    "sec-cve":         "CVE remediation, SCA, dependency hygiene",
    "sre-runbook":     "Incident runbooks, on-call, postmortems",
    "sre-slo":         "SLO/SLI/error budgets, observability",
    "data-sql":        "SQL queries, indexes, query plans, optimisation",
    "ai-eng":          "RAG, vLLM, fine-tuning, evals",
    "api-rest":        "REST API design, OpenAPI, idempotency",
    "test-pytest":     "pytest fixtures, parametrize, markers",
}


# Process-wide lazy-loaded model/tokenizer cache (populated by _load()).
_model = None
_tokenizer = None


def _load():
    """Lazily build and cache the quantized model + tokenizer.

    First call loads the base model in 4-bit NF4 (bitsandbytes) and
    best-effort applies the Surrogate-1 LoRA adapter; later calls just
    return the cached pair.

    Returns:
        (model, tokenizer) — the module-level cached instances.
    """
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer

    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              BitsAndBytesConfig)

    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True)
    # Tokenizer may ship without a pad token; reuse EOS so generate()
    # always has a valid pad id.
    if _tokenizer.pad_token_id is None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    _model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True,
        device_map="cuda", quantization_config=quant_cfg)

    # Adapter attach is best-effort: on any failure we keep serving the
    # plain base model instead of crashing the Space.
    if LORA_REPO:
        try:
            from peft import PeftModel
            _model = PeftModel.from_pretrained(
                _model, LORA_REPO, token=HF_TOKEN or None)
            print(f"[ok] LoRA: {LORA_REPO}")
        except Exception as e:
            print(f"[skip] LoRA: {e}")
    return _model, _tokenizer


def _generate(prompt: str, max_tokens: int = 768,
              temperature: float = 0.7) -> str:
    """Run one chat-templated completion and return the decoded reply.

    Wraps *prompt* with the SYSTEM message, renders the model's chat
    template, samples a completion on CUDA, and decodes only the newly
    generated tokens (the prompt prefix is sliced off).
    """
    model, tokenizer = _load()
    conversation = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True)
    # Truncate very long prompts so we stay under the context window.
    enc = tokenizer(rendered, return_tensors="pt", truncation=True,
                    max_length=8000).to("cuda")
    generated = model.generate(
        **enc,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    prompt_len = enc["input_ids"].shape[1]
    completion = generated[0][prompt_len:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()


@spaces.GPU(duration=300)
def respond(message: str) -> str:
    """Chat endpoint (/run/respond): one low-temperature completion.

    Returns the placeholder "(empty)" for missing/blank input instead of
    spending GPU time on it.
    """
    if not (message and message.strip()):
        return "(empty)"
    return _generate(message, max_tokens=768, temperature=0.4)


@spaces.GPU(duration=600)
def synth_batch(domain: str, count) -> str:
    """Magpie-style synthetic pair generation. Returns N JSONL lines.

    Two-stage loop per pair: sample an instruction from a seed prompt
    (high temperature), then answer it (low temperature). Pairs that are
    too short on either side are dropped; per-item failures are logged
    and skipped so one bad generation never kills the batch.
    """
    domain = (domain or "code-python").strip()
    try:
        n = int(count or 12)
    except (TypeError, ValueError):
        n = 12
    n = min(20, max(1, n))  # hard cap: at most 20 pairs per call
    hint = DOMAIN_HINTS.get(domain, domain)

    seed = (f"Generate ONE realistic technical question a senior engineer "
            f"would ask about {hint}. Output JUST the question text, no "
            f"preamble or quotes. Make it specific and answerable in "
            f"200-500 words with code/config examples.")

    jsonl_lines = []
    for _ in range(n):
        try:
            raw = _generate(seed, max_tokens=200, temperature=0.95)
            # Keep only the first line, drop wrapping quotes, cap length.
            first_line = raw.split("\n")[0].strip()
            instruction = first_line.strip('"').strip("'")[:600]
            if len(instruction) < 30:
                continue
            answer = _generate(instruction, max_tokens=900,
                               temperature=0.4)
            if len(answer) < 80:
                continue
            record = {
                "prompt": instruction,
                "response": answer,
                "source": f"surrogate-1-zero-gpu/synth-{domain}",
                "meta": {"domain": domain, "magpie": True},
            }
            jsonl_lines.append(json.dumps(record, ensure_ascii=False))
        except Exception as e:
            print(f"[synth_batch] err: {e}")
            continue
    return "\n".join(jsonl_lines)


# --- UI / API surface -----------------------------------------------------
# gr.Blocks with an explicit api_name on each click handler exposes both
# /run/respond and /run/synth_batch (component creation order defines the
# rendered layout, so this wiring is kept as-is).
with gr.Blocks(title="Surrogate-1 ZeroGPU") as demo:
    gr.Markdown("# Surrogate-1 (7B + v1 LoRA, ZeroGPU A10G)")
    gr.Markdown(
        "Qwen2.5-Coder-7B + Surrogate-1 v1 LoRA on free PRO ZeroGPU. "
        "Two API endpoints: `/run/respond` (chat) and `/run/synth_batch` "
        "(synthetic training pair batch — used by synth-puller cron).")

    # Chat tab: free-form prompt -> respond() -> text answer.
    with gr.Tab("chat"):
        chat_in = gr.Textbox(
            lines=4,
            placeholder="ask Surrogate-1: code, devops, security…")
        chat_out = gr.Textbox(lines=20, label="response")
        gr.Button("send", variant="primary").click(
            respond, chat_in, chat_out, api_name="respond")
        gr.Examples(
            [["Write a Terraform module for AWS S3 with KMS encryption "
              "+ versioning."],
             ["Implement Redis-based rate limit per-API-key in FastAPI."],
             ["Diagnose: Lambda cold-start 3s on 256MB. "
              "Architecture options?"]],
            inputs=chat_in)

    # Synth tab: (domain, count) -> synth_batch() -> JSONL string.
    with gr.Tab("synth_batch"):
        gr.Markdown(
            "Magpie-style: model generates instructions per domain, then "
            "responds. Output is JSONL (one pair per line). Domains: "
            + ", ".join(sorted(DOMAIN_HINTS.keys())))
        synth_dom = gr.Textbox(value="code-python", label="domain")
        synth_cnt = gr.Number(value=12, precision=0, label="count (1-20)")
        synth_out = gr.Textbox(lines=20, label="JSONL pairs")
        gr.Button("generate", variant="primary").click(
            synth_batch, [synth_dom, synth_cnt], synth_out,
            api_name="synth_batch")


# Queue bounds concurrent API callers (the synth-puller cron + UI users).
demo.queue(max_size=8).launch()