File size: 2,991 Bytes
26b79c3
5894cc6
26b79c3
5894cc6
 
26b79c3
a58f829
26b79c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5894cc6
a58f829
 
 
 
26b79c3
a58f829
26b79c3
 
 
5894cc6
26b79c3
5894cc6
a58f829
26b79c3
 
 
 
a58f829
26b79c3
 
 
 
 
 
 
5894cc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26b79c3
5894cc6
 
 
 
 
 
26b79c3
5894cc6
26b79c3
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import json
import torch
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from gradio import Server
from fastapi.responses import HTMLResponse
import spaces

# ── Model Setup ──────────────────────────────────────────────────────────────

MODEL_ID = "Zyphra/ZAYA1-8B"

# Tokenizer and model are both loaded with trust_remote_code=True — presumably
# because the Hub repository ships custom tokenizer/modeling code (confirm
# against the model card before removing it).
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",  # automatic placement of the weights on available devices
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
)

# ── Gradio Server ────────────────────────────────────────────────────────────

# The single Gradio ``Server`` for this file: ``generate`` is registered on it
# via ``@app.api()``, the HTML frontend via ``@app.get("/")``, and it is
# started by the ``app.launch(...)`` call at the bottom of the file.
app = Server()


# ── Streaming API endpoint (generator yields cumulative text) ────────────────

def _history_to_messages(history):
    """Normalize the ``history`` argument of ``generate`` into chat messages.

    Accepts either a JSON string (the API default is ``"[]"``) or an
    already-decoded list of turns. ``None``, an empty/whitespace-only string
    and a JSON ``null`` are all treated as "no history" rather than crashing.
    Only the ``role`` and ``content`` keys of each turn are kept.
    """
    if isinstance(history, str):
        history = json.loads(history) if history.strip() else None
    return [
        {"role": turn["role"], "content": turn["content"]}
        for turn in (history or [])
    ]


@app.api()
@spaces.GPU(duration=120)
def generate(
    message: str,
    history: str = "[]",
    system_prompt: str = "You are ZAYA1-8B, a highly capable reasoning assistant built by Zyphra. You excel at detailed long-form reasoning, mathematics, and coding. Think step by step when solving complex problems.",
    temperature: float = 1.0,
    top_p: float = 0.95,
    max_new_tokens: int = 2048,
) -> str:
    """Stream a response from ZAYA1-8B token by token.

    NOTE: this is a generator; the ``-> str`` annotation describes each
    yielded value (presumably what Gradio reads for the API schema — confirm
    before changing it).

    Args:
        message: The new user message.
        history: Prior turns as a JSON string (or an already-decoded list) of
            ``{"role": ..., "content": ...}`` dicts. Empty/``None`` means no
            history.
        system_prompt: Prepended as a ``system`` message when non-empty.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding, because sampling with a non-positive temperature is
            rejected by transformers.
        top_p: Nucleus-sampling threshold (unused for greedy decoding).
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The cumulative response text after each decoded chunk.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread,
        re-raised here after the stream has been closed.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
    )
    if temperature is None or temperature > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
        )
    else:
        # temperature <= 0 cannot be used for sampling; treat it as greedy.
        gen_kwargs["do_sample"] = False

    errors = []

    def _run_generate():
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            # model.generate closes the streamer when decoding finishes; if it
            # raises first, nothing does, and the consumer loop below would
            # block forever waiting for more text.
            streamer.end()

    thread = threading.Thread(target=_run_generate)
    thread.start()

    full_text = ""
    for new_text in streamer:
        full_text += new_text
        yield full_text

    thread.join()
    if errors:
        # Surface the worker-thread failure to the caller instead of silently
        # ending the stream.
        raise errors[0]


# ── Serve Frontend ────────────────────────────────────────────────────────────

@app.get("/")
async def homepage():
    """Serve the frontend page stored as ``index.html`` beside this script.

    The file is re-read on every request and returned as an HTML response.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    index_file = os.path.join(script_dir, "index.html")
    with open(index_file, "r", encoding="utf-8") as handle:
        page = handle.read()
    return HTMLResponse(content=page)


# Start the server at import/run time. show_error=True asks Gradio to report
# handler exceptions to the client (presumably including ones re-raised from
# ``generate`` — confirm against the Gradio version in use).
app.launch(show_error=True)