"""Surrogate-1 ZeroGPU — chat + synth-batch endpoints (gr.Blocks 4.44).

Two functions exposed via Gradio API:
  • POST /run/respond      — single chat completion (also UI tab)
  • POST /run/synth_batch  — Magpie-style synthetic training pair batch

synth_batch is hit by ~/.surrogate/bin/v2/synth-puller.sh every 5 min
on the bulk Space, draining free PRO ZeroGPU budget into training data.
Each call returns up to 20 JSONL pairs as a single string.

Earlier ChatInterface attempts hit a starlette TemplateResponse failure
during gradio's static-route init. gr.Blocks with an explicit api_name on
each click avoids that code path and exposes both endpoints cleanly.

Backbone: Qwen2.5-Coder-7B-Instruct + Surrogate-1 v1 LoRA, bnb int4.
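
Example client call (a sketch; the Space id below is a placeholder, and the
synth-puller script may hit the raw REST endpoint instead):

    from gradio_client import Client
    client = Client("axentx/surrogate-1-zero-gpu")  # hypothetical Space id
    jsonl = client.predict("devops-k8s", 20, api_name="/synth_batch")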
"""
import json
import os

import gradio as gr
import spaces
import torch


BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
LORA_REPO = os.environ.get("LORA_REPO", "axentx/surrogate-1-coder-7b-lora-v1")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

SYSTEM = (
    "You are Surrogate-1 v1 (Qwen2.5-Coder-7B + LoRA), expert DevSecOps + "
    "SRE + coding agent.\n\n"
    "CRITICAL — knowledge cutoff: your training data ends Sept 2024. "
    "Anything launched AFTER Sept 2024 (cloud regions, framework versions, "
    "API changes, model releases) is OUTSIDE your knowledge. When the user "
    "asks about post-2024 facts, ALWAYS say:\n"
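    # (The Thai line below reads: "I'm not sure; knowledge cutoff Sept 2024.
    # Please check official docs.")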
    "  'ผมไม่แน่ใจครับ — knowledge cutoff Sept 2024. โปรดเช็ค official docs.'\n"
    "Do NOT deny the existence of newer things — the user almost certainly "
    "knows more than you about post-Sept-2024 changes. If user says X "
    "exists and X sounds plausible (e.g. AWS region named ap-southeast-N, "
    "a new model release, a new framework version), TRUST THE USER, don't "
    "argue. Reply with what you know about adjacent context.\n\n"
    "Cite real APIs only. Say IDK rather than confabulate. When in doubt, "
    "ask the user to verify rather than asserting wrong info."
)

DOMAIN_HINTS = {
    "code-python":     "Python coding tasks, idiomatic, type-hinted",
    "code-typescript": "TypeScript / React / Node tasks, strict types",
    "code-rust":       "Rust ownership, async, performance",
    "code-go":         "Go concurrency, stdlib, microservices",
    "devops-tf":       "Terraform AWS/GCP modules, best practices",
    "devops-k8s":      "Kubernetes manifests, helm, troubleshooting",
    "devops-cdk":      "AWS CDK constructs, TypeScript",
    "ci-github":       "GitHub Actions workflows, reusable, secure",
    "sec-iam":         "IAM least-privilege policies, AssumeRole",
    "sec-cve":         "CVE remediation, SCA, dependency hygiene",
    "sre-runbook":     "Incident runbooks, on-call, postmortems",
    "sre-slo":         "SLO/SLI/error budgets, observability",
    "data-sql":        "SQL queries, indexes, query plans, optimisation",
    "ai-eng":          "RAG, vLLM, fine-tuning, evals",
    "api-rest":        "REST API design, OpenAPI, idempotency",
    "test-pytest":     "pytest fixtures, parametrize, markers",
}


_model = None
_tokenizer = None


def _load():
    global _model, _tokenizer
    if _model is not None:
        return _model, _tokenizer
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              BitsAndBytesConfig)
    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True)
    if _tokenizer.pad_token_id is None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id
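    # NF4 with double quantization: the 7B backbone fits in roughly 4-5 GB
    # of VRAM, comfortably within the ZeroGPU slice.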
    bnb = BitsAndBytesConfig(load_in_4bit=True,
                             bnb_4bit_compute_dtype=torch.bfloat16,
                             bnb_4bit_quant_type="nf4",
                             bnb_4bit_use_double_quant=True)
    _model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, token=HF_TOKEN or None, trust_remote_code=True,
        device_map="cuda", quantization_config=bnb)
    if LORA_REPO:
        try:
            from peft import PeftModel
            _model = PeftModel.from_pretrained(
                _model, LORA_REPO, token=HF_TOKEN or None)
            print(f"[ok] LoRA: {LORA_REPO}")
        except Exception as e:
            print(f"[skip] LoRA: {e}")
    return _model, _tokenizer


def _generate(prompt: str, max_tokens: int = 768,
              temperature: float = 0.7) -> str:
    model, tokenizer = _load()
    msgs = [{"role": "system", "content": SYSTEM},
            {"role": "user", "content": prompt}]
    chat = tokenizer.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True)
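    # apply_chat_template renders Qwen's ChatML framing
    # (<|im_start|>role ... <|im_end|>) and appends the assistant header.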
    inputs = tokenizer(chat, return_tensors="pt", truncation=True,
                       max_length=8000).to("cuda")
    out = model.generate(
        **inputs,
        max_new_tokens=max_tokens, temperature=temperature, do_sample=True,
        top_p=0.9, pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(
        out[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True).strip()


@spaces.GPU(duration=300)
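# ZeroGPU grants a GPU slice per call; duration caps the allocation in
# seconds (5 min for chat, 10 min below for the full 20-pair synth loop).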
def respond(message: str) -> str:
    if not message or not message.strip():
        return "(empty)"
    return _generate(message, max_tokens=768, temperature=0.4)


@spaces.GPU(duration=600)
def synth_batch(domain: str, count) -> str:
    """Magpie-style synthetic pair generation. Returns N JSONL lines."""
    domain = (domain or "code-python").strip()
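    # Each emitted JSONL line has this shape (values illustrative):
    #   {"prompt": "How do I ...?", "response": "...",
    #    "source": "surrogate-1-zero-gpu/synth-<domain>",
    #    "meta": {"domain": "<domain>", "magpie": true}}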
    try:
        count = int(count or 12)
    except (TypeError, ValueError):
        count = 12
    count = max(1, min(20, count))
    hint = DOMAIN_HINTS.get(domain, domain)

    seed = (f"Generate ONE realistic technical question a senior engineer "
            f"would ask about {hint}. Output JUST the question text, no "
            f"preamble or quotes. Make it specific and answerable in "
            f"200-500 words with code/config examples.")

    pairs = []
    for _ in range(count):
        try:
            instruction = _generate(seed, max_tokens=200, temperature=0.95)
            instruction = (instruction.split("\n")[0]
                           .strip().strip('"').strip("'")[:600])
            if len(instruction) < 30:
                continue
            response = _generate(instruction, max_tokens=900,
                                 temperature=0.4)
            if len(response) < 80:
                continue
            pairs.append(json.dumps({
                "prompt": instruction,
                "response": response,
                "source": f"surrogate-1-zero-gpu/synth-{domain}",
                "meta": {"domain": domain, "magpie": True},
            }, ensure_ascii=False))
        except Exception as e:
            print(f"[synth_batch] err: {e}")
            continue
    return "\n".join(pairs)


with gr.Blocks(title="Surrogate-1 ZeroGPU") as demo:
    gr.Markdown("# Surrogate-1 (7B + v1 LoRA, ZeroGPU A10G)")
    gr.Markdown(
        "Qwen2.5-Coder-7B + Surrogate-1 v1 LoRA on free PRO ZeroGPU. "
        "Two API endpoints: `/run/respond` (chat) and `/run/synth_batch` "
        "(synthetic training pair batch — used by synth-puller cron).")

    with gr.Tab("chat"):
        chat_in = gr.Textbox(
            lines=4,
            placeholder="ask Surrogate-1: code, devops, security…")
        chat_out = gr.Textbox(lines=20, label="response")
        gr.Button("send", variant="primary").click(
            respond, chat_in, chat_out, api_name="respond")
        gr.Examples(
            [["Write a Terraform module for AWS S3 with KMS encryption "
              "+ versioning."],
             ["Implement Redis-based rate limit per-API-key in FastAPI."],
             ["Diagnose: Lambda cold-start 3s on 256MB. "
              "Architecture options?"]],
            inputs=chat_in)

    with gr.Tab("synth_batch"):
        gr.Markdown(
            "Magpie-style: model generates instructions per domain, then "
            "responds. Output is JSONL (one pair per line). Domains: "
            + ", ".join(sorted(DOMAIN_HINTS.keys())))
        synth_dom = gr.Textbox(value="code-python", label="domain")
        synth_cnt = gr.Number(value=12, precision=0, label="count (1-20)")
        synth_out = gr.Textbox(lines=20, label="JSONL pairs")
        gr.Button("generate", variant="primary").click(
            synth_batch, [synth_dom, synth_cnt], synth_out,
            api_name="synth_batch")


demo.queue(max_size=8).launch()