import os
import json
import torch
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from gradio import Server
from fastapi.responses import HTMLResponse
import spaces

# ── Model Setup ──────────────────────────────────────────────────────────────

MODEL_ID = "Zyphra/ZAYA1-8B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# ── Gradio Server ────────────────────────────────────────────────────────────

app = Server()


# ── Streaming API endpoint (generator yields cumulative text) ────────────────

@app.api()
@spaces.GPU(duration=120)
def generate(
    message: str,
    history: str = "[]",
    system_prompt: str = "You are ZAYA1-8B, a highly capable reasoning assistant built by Zyphra. You excel at detailed long-form reasoning, mathematics, and coding. Think step by step when solving complex problems.",
    temperature: float = 1.0,
    top_p: float = 0.95,
    max_new_tokens: int = 2048,
) -> str:
    """Stream a response from ZAYA1-8B token by token.

    This is a generator: each yielded value is the *cumulative* response text
    decoded so far, not just the newest chunk.

    Args:
        message: The new user message.
        history: Prior conversation turns, either as a JSON string or an
            already-decoded list. Each turn must have "role" and "content"
            keys. An empty string, None, or JSON ``null`` means no history.
        system_prompt: System message prepended to the conversation; skipped
            when empty.
        temperature: Sampling temperature. Values <= 0 select greedy
            decoding, because sampling requires a strictly positive
            temperature.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The full response text generated so far.

    Raises:
        json.JSONDecodeError: If ``history`` is a non-empty, malformed JSON
            string.
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here once the stream has been closed.
    """
    # Treat "", None and JSON null as an empty history instead of crashing.
    if isinstance(history, str):
        hist = json.loads(history) if history.strip() else []
    else:
        hist = history
    hist = hist or []

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(
        {"role": turn["role"], "content": turn["content"]} for turn in hist
    )
    messages.append({"role": "user", "content": message})

    # return_dict=True yields input_ids *and* attention_mask, so generate()
    # does not have to guess the mask.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    gen_kwargs = dict(inputs)
    gen_kwargs.update(max_new_tokens=max_new_tokens, streamer=streamer)
    if temperature > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
        )
    else:
        gen_kwargs["do_sample"] = False

    errors = []

    def _run_generation():
        """Run model.generate in the worker thread, capturing any failure."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised in the consumer below
            errors.append(exc)
            # The streamer iterator blocks until it receives a stop signal,
            # which generate() only sends on success; send it here so the
            # consumer loop cannot hang forever.
            streamer.end()

    thread = threading.Thread(target=_run_generation, daemon=True)
    thread.start()

    full_text = ""
    for new_text in streamer:
        full_text += new_text
        yield full_text
    thread.join()

    if errors:
        raise errors[0]
# ── Serve Frontend ───────────────────────────────────────────────────────────

@app.get("/")
async def homepage():
    """Return the contents of ``index.html`` (next to this file) as HTML.

    The file is opened and read on every request rather than cached.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(base_dir, "index.html")
    with open(page_path, "r", encoding="utf-8") as page_file:
        markup = page_file.read()
    return HTMLResponse(content=markup)


app.launch(show_error=True)