# ZAYA1-8B / app.py
# Author: akhaliq (HF Staff)
# feat: implement streaming token generation and adjust max token constraints
# commit: 5894cc6
import json
import os
import threading
from collections.abc import Iterator

import spaces
import torch
from fastapi.responses import HTMLResponse
from gradio import Server
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ── Model Setup ──────────────────────────────────────────────────────────────
# Load tokenizer and model once at import time so every request reuses them.
MODEL_ID = "Zyphra/ZAYA1-8B"
# NOTE(review): trust_remote_code executes code shipped in the model repo —
# presumably required by this architecture; confirm the repo/revision is trusted.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
# ── Gradio Server ────────────────────────────────────────────────────────────
app = Server()
# ── Streaming API endpoint (generator yields cumulative text) ────────────────
@app.api()
@spaces.GPU(duration=120)
def generate(
    message: str,
    history: str = "[]",
    system_prompt: str = "You are ZAYA1-8B, a highly capable reasoning assistant built by Zyphra. You excel at detailed long-form reasoning, mathematics, and coding. Think step by step when solving complex problems.",
    temperature: float = 1.0,
    top_p: float = 0.95,
    max_new_tokens: int = 2048,
) -> Iterator[str]:
    """Stream a response from ZAYA1-8B token by token.

    Yields the cumulative generated text after each new chunk, so the
    client can simply replace its display with the latest value.

    Args:
        message: The new user message.
        history: Prior turns, either a JSON-encoded list of
            {"role": ..., "content": ...} dicts or an already-parsed list.
        system_prompt: System message prepended to the conversation;
            skipped entirely when empty.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold.
        max_new_tokens: Cap on the number of generated tokens.

    Yields:
        The full response text accumulated so far (cumulative, not deltas).
    """
    # history may arrive as a JSON string (HTTP API) or a native list.
    hist = json.loads(history) if isinstance(history, str) else history
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Copy only role/content so stray client-supplied keys never reach the template.
    for turn in hist:
        messages.append({"role": turn["role"], "content": turn["content"]})
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # timeout keeps this generator from blocking forever if the worker
    # thread crashes before producing its first token.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=120.0
    )
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
            do_sample=True,
            streamer=streamer,
        ),
        daemon=True,  # a stuck generation must not block interpreter shutdown
    )
    thread.start()
    try:
        full_text = ""
        for new_text in streamer:
            full_text += new_text
            yield full_text
    finally:
        # Runs even if the client disconnects mid-stream (generator is
        # closed early) — previously the join was silently skipped then.
        thread.join()
# ── Serve Frontend ────────────────────────────────────────────────────────────
@app.get("/")
async def homepage():
    """Serve the single-page frontend: index.html sitting next to this script."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    index_path = os.path.join(base_dir, "index.html")
    with open(index_path, "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup)
app.launch(show_error=True)