import os
import re
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

MODEL_ID = "google/gemma-4-31B-it"

processor = AutoProcessor.from_pretrained(MODEL_ID, token=os.environ.get("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),
)

# Markers that delimit the model's thinking block in the raw output stream.
# NOTE: the closing marker was missing from the original source; "<|channel>final"
# is an assumption matching the opening marker's format -- adjust it to whatever
# this model's chat template actually emits.
THINK_START = "<|channel>thought"
THINK_END = "<|channel>final"


@spaces.GPU
def generate(message, history):
    # Rebuild the conversation in the chat-template message format.
    messages = [{"role": msg["role"], "content": msg["content"]} for msg in history]
    messages.append({"role": "user", "content": message})

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)

    # Keep special tokens so the thinking markers remain visible in the stream.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=32000,
            temperature=0.5,
            top_p=0.9,
            do_sample=True,
            streamer=streamer,
        ),
    )
    thread.start()

    raw = ""
    for new_text in streamer:
        raw += new_text

        # While we are still inside the thinking block, show a placeholder.
        if THINK_START in raw and THINK_END not in raw:
            yield "🤔 *Thinking...*"
            continue

        # Once the end-of-thinking marker has appeared, stream only the part after it.
        if THINK_END in raw:
            answer = raw.split(THINK_END, 1)[1].strip()
        else:
            answer = raw.strip()

        # Strip trailing EOS/end-of-turn tokens. The token strings were missing
        # from the original source; these are the usual Gemma special tokens.
        answer = re.sub(r"<end_of_turn>\s*$", "", answer).strip()
        answer = re.sub(r"<eos>\s*$", "", answer).strip()
        if answer:
            yield answer


# type="messages" makes history a list of {"role": ..., "content": ...} dicts,
# matching how generate() reads it above.
demo = gr.ChatInterface(fn=generate, title="Gemma 4 – 31B", type="messages")

if __name__ == "__main__":
    demo.launch()