"""axentx coder-zero-gpu-1 — Qwen2.5-Coder-32B-Instruct-AWQ on ZeroGPU.

Exposes OpenAI-compatible /v1/chat/completions so the axentx pipeline's
LLM chain can hit it like any other upstream provider.
"""
import os
import time
import spaces
import torch
import gradio as gr
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = os.environ.get("MODEL_ID", "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")

print(f"[init] loading tokenizer: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
print(f"[init] loading model")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
)
print(f"[init] ready")


@spaces.GPU(duration=120)
def _generate(messages, max_tokens=1024, temperature=0.3):
    # Render the chat with the model's own template, then generate on the
    # ZeroGPU-attached device.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    out = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=max(temperature, 0.01),  # clamp: temperature=0 is invalid for sampling
        do_sample=temperature > 0,           # fall back to greedy decoding at temperature <= 0
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    text = tokenizer.decode(
        out[0][inputs.input_ids.shape[1]:],
        skip_special_tokens=True,
    )
    return text


app = FastAPI(title="axentx coder ZeroGPU")
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)


class ChatRequest(BaseModel):
    messages: list
    max_tokens: int = 1024
    temperature: float = 0.3
    model: str = "axentx-coder-2"


@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
    t0 = time.time()
    text = _generate(req.messages, req.max_tokens, req.temperature)
    return {
        "id": f"axentx-{int(t0)}",
        "object": "chat.completion",
        "created": int(t0),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": len(text.split()),
            "total_tokens": len(text.split()),
        },
    }


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_ID}


def _ui_chat(message, history):
    # Convert Gradio "messages"-style history into the chat format _generate expects.
    msgs = []
    for h in history:
        if h.get("role") and h.get("content"):
            msgs.append({"role": h["role"], "content": h["content"]})
    msgs.append({"role": "user", "content": message})
    return _generate(msgs, max_tokens=1024, temperature=0.3)


demo = gr.ChatInterface(
    _ui_chat,
    title="axentx Coder — Qwen2.5-Coder-32B-Instruct (ZeroGPU)",
    type="messages",
)

app = gr.mount_gradio_app(app, demo, path="/")
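
# Local run (a sketch, assuming this file is saved as app.py and uvicorn is
# installed); serves the FastAPI routes and the mounted Gradio UI on one port:
#   uvicorn app:app --host 0.0.0.0 --port 7860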