# app.py – Gradio chat demo for google/gemma-4-31B-it
import gradio as gr
import torch
import spaces
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import os
import re
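
# The checkpoint is loaded with an HF_TOKEN read from the environment (e.g. a
# Space secret), since Gemma models are gated on the Hub.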
MODEL_ID = "google/gemma-4-31B-it"

processor = AutoProcessor.from_pretrained(MODEL_ID, token=os.environ.get("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),
)
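

# @spaces.GPU requests a GPU slot for each call when this runs as a Hugging Face
# Space on ZeroGPU hardware; outside of Spaces the decorator should be a no-op.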
@spaces.GPU
def generate(message, history):
    # Rebuild the conversation as a list of role/content dicts and append the
    # new user message.
    messages = []
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": message})

    # Render the chat template with the model's thinking mode enabled.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)
    # Stream tokens from a background generation thread; special tokens are kept
    # so the thinking-channel markers can be detected below.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
    thread = Thread(target=model.generate, kwargs=dict(
        **inputs,
        max_new_tokens=32000,
        temperature=0.5,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
    ))
    thread.start()

    raw = ""
    for new_text in streamer:
        raw += new_text
        # While the thinking block is still open, show a "Thinking..." placeholder.
        if "<|channel>thought" in raw and "<channel|>" not in raw:
            yield "🤔 *Thinking...*"
            continue

        # Once the closing <channel|> marker appears, stream only the part after it.
        if "<channel|>" in raw:
            answer = raw.split("<channel|>", 1)[1].strip()
        else:
            answer = raw.strip()

        # Strip trailing EOS/turn tokens.
        answer = re.sub(r"<turn\|>\s*$", "", answer).strip()
        answer = re.sub(r"<end_of_turn>\s*$", "", answer).strip()

        if answer:
            yield answer
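

# type="messages" makes Gradio pass `history` as a list of role/content dicts,
# matching what generate() expects.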
demo = gr.ChatInterface(fn=generate, type="messages", title="Gemma 4 – 31B")

if __name__ == "__main__":
demo.launch()