| import os |
| import json |
| import torch |
| import threading |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
| from gradio import Server |
| from fastapi.responses import HTMLResponse |
| import spaces |
|
|
| |
|
|
MODEL_ID = "Zyphra/ZAYA1-8B"

# Load tokenizer and weights once, at import time, so every request shares them.
# trust_remote_code=True lets transformers execute the modelling/tokenizer code
# published in the Hub repository for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",  # let transformers/accelerate decide weight placement
    torch_dtype=torch.bfloat16,  # keep weights in bfloat16
)

# Server object on which the API endpoint and HTTP routes of this file are registered.
app = Server()
|
|
|
|
| |
|
|
@app.api()
@spaces.GPU(duration=120)
def generate(
    message: str,
    history: str = "[]",
    system_prompt: str = "You are ZAYA1-8B, a highly capable reasoning assistant built by Zyphra. You excel at detailed long-form reasoning, mathematics, and coding. Think step by step when solving complex problems.",
    temperature: float = 1.0,
    top_p: float = 0.95,
    max_new_tokens: int = 2048,
) -> str:
    """Stream a response from ZAYA1-8B token by token.

    This function is a generator: every yielded value is the *entire* reply
    produced so far (cumulative), not only the newest chunk.

    Args:
        message: The new user turn.
        history: Earlier turns as a JSON list of objects with ``"role"`` and
            ``"content"`` keys. An already-decoded list is accepted too; an
            empty/whitespace-only string, ``None`` or JSON ``null`` means
            "no history".
        system_prompt: System message placed before the history; omitted
            when empty.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding, because transformers rejects sampling with a
            non-positive temperature.
        top_p: Nucleus-sampling cutoff; only used when sampling.
        max_new_tokens: Maximum number of tokens to generate.

    Raises:
        json.JSONDecodeError: ``history`` is a non-empty string that is not
            valid JSON.
        KeyError: a history entry lacks ``"role"`` or ``"content"``.
        Exception: whatever ``model.generate`` raised on the worker thread is
            re-raised here after the stream has been closed.

    NOTE(review): ``-> str`` describes each yielded value rather than the
    generator itself; it is kept because Gradio presumably derives the API
    schema from it — confirm before changing.
    """
    # --- Decode history -------------------------------------------------
    if isinstance(history, str):
        hist = json.loads(history) if history.strip() else []
    else:
        hist = history
    hist = hist or []  # None / JSON null -> no prior turns

    # --- Build the chat transcript ---------------------------------------
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend({"role": turn["role"], "content": turn["content"]} for turn in hist)
    messages.append({"role": "user", "content": message})

    # return_dict=True returns input_ids *and* attention_mask, so generate()
    # does not have to infer the mask from the ids.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    gen_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        streamer=streamer,
    )
    if temperature > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
        )
    else:
        # Sampling with temperature <= 0 raises inside transformers; decode greedily.
        gen_kwargs["do_sample"] = False

    errors = []

    def _worker():
        # Runs model.generate on a separate thread so this generator can
        # consume the streamer while tokens are produced.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() did not reach its own streamer.end(); send the stop
            # signal so the consuming loop below does not block forever.
            streamer.end()

    thread = threading.Thread(target=_worker)
    thread.start()

    full_text = ""
    for new_text in streamer:
        full_text += new_text
        yield full_text

    thread.join()
    if errors:
        # Surface the worker's failure to the caller instead of ending silently.
        raise errors[0]
|
|
|
|
| |
|
|
@app.get("/")
async def homepage():
    """Serve ``index.html`` from the directory that contains this script.

    The file is read on every request and returned as an HTML response.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(script_dir, "index.html")
    with open(page_path, "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup)
|
|
|
|
# Start serving the app defined above. show_error=True is Gradio's launch()
# option for surfacing errors raised by handlers to the client.
app.launch(show_error=True)
|
|