"""Gradio chat demo that streams replies from a 4-bit quantized causal LM.

The module loads the tokenizer and model at import time, exposes a streaming
chat callback decorated with ``spaces.GPU``, and renders ``<think>`` sections of
the model output as collapsible ``<details>`` blocks in the chat window.
"""

import html
import os
import re
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

MODEL_ID = "OrionLLM/GRM-2.6-Opus"

TITLE = "GRM-2.6-Opus"
SUBTITLE = "Chat with GRM-2.6-Opus on ZeroGPU"
DESCRIPTION = (
    "Chat with GRM-2.6-Opus in a ZeroGPU Space, optimized with text-only chat, "
    "NF4 4-bit loading, bounded context, streaming output, and thinking parsing."
)
PLACEHOLDER = (
    "Ask GRM-2.6-Opus for code, debugging, planning, research, long-form reasoning, "
    "terminal-agent tasks, or complex multi-step workflows."
)

# Upper bound on prompt tokens fed to the model and on tokens generated per reply.
MAX_INPUT_TOKENS = 16384
INTERNAL_MAX_NEW_TOKENS = 4096

# Number of past (user, assistant) exchanges replayed to the model.
MAX_HISTORY_TURNS = 8

# Seconds of GPU time requested per call through ``spaces.GPU``.
GPU_DURATION_SECONDS = 180

HF_TOKEN = os.environ.get("HF_TOKEN")

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
torch.backends.cuda.matmul.allow_tf32 = True

BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=HF_TOKEN,
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# When the prompt exceeds MAX_INPUT_TOKENS, keep its tail: right-side truncation
# would cut off the newest user turn and the generation prompt.
tokenizer.truncation_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=HF_TOKEN,
    device_map={"": 0},
    dtype=torch.bfloat16,
    quantization_config=BNB_CONFIG,
    attn_implementation="sdpa",
    low_cpu_mem_usage=True,
)
model.eval()

# NOTE(review): the tag literals below were reconstructed from the surrounding
# code and docstrings; confirm they match the markers this model actually emits.
_THINK_OPEN_RE = re.compile(r"<think>", re.IGNORECASE)
_THINK_CLOSE_RE = re.compile(r"</think>", re.IGNORECASE)

# Patterns removed from a previously rendered assistant message, in order:
# a rendered <details> block, a complete <think> block, an unterminated <think>.
_RENDERED_DETAILS_RE = re.compile(
    r"(?is)<details[^>]*>\s*<summary>.*?</summary>.*?</details>"
)
_THINK_BLOCK_RE = re.compile(r"(?is)<think>.*?</think>")
_THINK_TAIL_RE = re.compile(r"(?is)<think>.*$")


def model_input_device():
    """Return the device of the model's first parameter."""
    return next(model.parameters()).device


def strip_thinking(text: str) -> str:
    """Return *text* with reasoning sections removed and whitespace trimmed.

    Removes rendered ``<details>`` blocks, complete ``<think>...</think>``
    blocks, and a trailing unterminated ``<think>`` section. Empty or ``None``
    input yields ``""``.
    """
    if not text:
        return ""

    for pattern in (_RENDERED_DETAILS_RE, _THINK_BLOCK_RE, _THINK_TAIL_RE):
        text = pattern.sub("", text)
    return text.strip()


def _thinking_block(thinking: str) -> str:
    """Wrap reasoning text in a collapsible block, escaping it for HTML."""
    escaped = html.escape(thinking.strip())
    return (
        "\n\n<details>"
        "<summary>🧠 Thinking</summary>\n\n"
        f"<pre>{escaped}</pre>\n\n"
        "</details>\n\n"
    )


def render_thinking(raw_text: str) -> str:
    """Convert ``<think>`` sections of model output into collapsible blocks.

    Text outside ``<think>...</think>`` is passed through unchanged. Each
    reasoning section becomes a ``<details>`` element with escaped content. An
    opening tag with no closing tag yet (as happens mid-stream) consumes the
    rest of the text as reasoning.

    Tags are matched case-insensitively on the original string, so the match
    offsets always index the text being sliced.
    """
    if not raw_text:
        return ""

    parts = []
    cursor = 0
    while True:
        opening = _THINK_OPEN_RE.search(raw_text, cursor)
        if opening is None:
            parts.append(raw_text[cursor:])
            break

        parts.append(raw_text[cursor:opening.start()])

        closing = _THINK_CLOSE_RE.search(raw_text, opening.end())
        if closing is None:
            # Still streaming the reasoning: everything left belongs to it.
            parts.append(_thinking_block(raw_text[opening.end():]))
            break

        parts.append(_thinking_block(raw_text[opening.end():closing.start()]))
        cursor = closing.end()

    return "".join(parts).strip()


def _content_to_text(content) -> str:
    """Flatten one history message's content into stripped plain text.

    Strings are stripped. For lists/tuples, string items and the ``"text"``
    value of dict items are joined with newlines; other items are ignored.
    ``None`` becomes ``""``; anything else goes through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, (list, tuple)):
        pieces = []
        for item in content:
            if isinstance(item, str):
                pieces.append(item)
            elif isinstance(item, dict) and isinstance(item.get("text"), str):
                pieces.append(item["text"])
        return "\n".join(pieces).strip()
    return str(content).strip()


def _history_to_turns(history):
    """Normalise chat history into a flat list of ``(role, text)`` tuples.

    Accepts ``(user, assistant)`` pairs as well as dicts carrying ``role`` and
    ``content`` keys. ``None`` is treated as an empty history.
    """
    turns = []
    for entry in history or []:
        if isinstance(entry, dict):
            turns.append((entry.get("role"), _content_to_text(entry.get("content"))))
        else:
            user_text, assistant_text = entry
            turns.append(("user", _content_to_text(user_text)))
            turns.append(("assistant", _content_to_text(assistant_text)))
    return turns


def build_messages(history, message):
    """Build the chat-template message list for the next model call.

    Args:
        history: Prior conversation, either ``(user, assistant)`` pairs or
            role/content dicts; may be ``None`` or empty.
        message: The new user message.

    Returns:
        A list of ``{"role", "content"}`` dicts holding at most the last
        ``MAX_HISTORY_TURNS`` exchanges followed by the new user message.
        Assistant entries have their reasoning sections stripped; entries that
        end up empty are omitted.
    """
    turns = _history_to_turns(history)[-2 * MAX_HISTORY_TURNS:]

    # A window cut through dict-style history may begin with an assistant
    # reply; drop it so the replayed context starts on a user turn.
    while turns and turns[0][0] != "user":
        turns.pop(0)

    messages = []
    for role, text in turns:
        if role == "assistant":
            text = strip_thinking(text)
        elif role != "user":
            continue
        if text:
            messages.append({"role": role, "content": text})

    messages.append({"role": "user", "content": message.strip()})
    return messages


def estimate_duration(
    message,
    history,
    enable_thinking,
    preserve_thinking,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
):
    """Return the GPU duration for a call; a constant, whatever the inputs.

    The signature mirrors ``stream_chat`` because it is passed as the
    ``duration`` callable of ``spaces.GPU``.
    """
    del message, history, enable_thinking, preserve_thinking
    del temperature, top_p, top_k, repetition_penalty
    return GPU_DURATION_SECONDS


@spaces.GPU(duration=estimate_duration, size="large")
def stream_chat(
    message: str,
    history: list,
    enable_thinking: bool,
    preserve_thinking: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Stream the model's reply to *message*, yielding the rendered text so far.

    Args:
        message: New user message; blank input yields a single empty string.
        history: Prior conversation as accepted by ``build_messages``.
        enable_thinking: Forwarded to the tokenizer's chat template.
        preserve_thinking: Forwarded to the tokenizer's chat template.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus sampling threshold (used only when sampling).
        top_k: Top-k sampling cutoff (used only when sampling).
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        The accumulated output after each streamed chunk, passed through
        ``render_thinking``.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised once the stream has ended.
    """
    if not message or not message.strip():
        yield ""
        return

    messages = build_messages(history, message)

    rendered_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
        preserve_thinking=preserve_thinking,
    )

    inputs = tokenizer(
        rendered_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    ).to(model_input_device())

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=120.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    do_sample = temperature > 0
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=INTERNAL_MAX_NEW_TOKENS,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Sampling-only knobs are left out for greedy decoding, where they
        # have no effect.
        generation_kwargs.update(
            temperature=max(temperature, 1e-5),
            top_p=top_p,
            top_k=int(top_k),
        )

    failures = []

    def _generate():
        """Run generation; on failure record the error and end the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the consumer below
            failures.append(exc)
            # Unblock the consumer instead of letting it wait for the timeout.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    raw_output = ""
    for chunk in streamer:
        raw_output += chunk
        yield render_thinking(raw_output)

    worker.join()
    if failures:
        raise failures[0]


CSS = """
.gradio-container {
    max-width: 1180px !important;
    margin: 0 auto !important;
}

.title h1 {
    text-align: center;
    margin-bottom: 0.2rem !important;
}

.subtitle p,
.meta p {
    text-align: center;
}

.meta p {
    font-size: 0.95rem;
    color: #6b7280;
    margin-top: 0.35rem !important;
}

.duplicate-button {
    margin: 0 auto 14px auto !important;
}

details {
    border: 1px solid #37415133;
    border-radius: 12px;
    padding: 0.75rem 1rem;
    margin: 0.5rem 0 1rem 0;
    background: rgba(127, 127, 127, 0.08);
}

summary {
    cursor: pointer;
    font-weight: 600;
}

pre {
    white-space: pre-wrap;
    word-break: break-word;
    margin: 0.75rem 0 0 0;
}
"""

# sanitize_html is off so the <details> blocks built by render_thinking are
# displayed as HTML rather than escaped.
chatbot = gr.Chatbot(
    height=680,
    placeholder=PLACEHOLDER,
    sanitize_html=False,
)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.Markdown(f"# {TITLE}", elem_classes="title")
    gr.Markdown(SUBTITLE, elem_classes="subtitle")
    gr.Markdown(
        f"{DESCRIPTION} Model: [{MODEL_ID}](https://huggingface.co/{MODEL_ID})",
        elem_classes="meta",
    )

    gr.DuplicateButton("Duplicate Space", elem_classes="duplicate-button")

    # The additional inputs are passed to stream_chat positionally, in this
    # order, after (message, history).
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(
            "⚙️ Parameters",
            open=False,
            render=False,
        ),
        additional_inputs=[
            gr.Checkbox(
                value=True,
                label="Enable thinking",
                render=False,
            ),
            gr.Checkbox(
                value=False,
                label="Preserve thinking across turns",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.2,
                step=0.05,
                value=1.0,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                step=0.05,
                value=0.95,
                label="Top-p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=20,
                label="Top-k",
                render=False,
            ),
            gr.Slider(
                minimum=1.0,
                maximum=1.5,
                step=0.05,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Design a production-ready architecture for a local AI terminal-agent platform using GRM-2.6-Opus."],
            ["Write a detailed debugging plan for a flaky async Python test suite."],
            ["Build a responsive landing page in React and Tailwind for a premium AI coding product."],
            ["Create an agentic workflow plan for solving a Terminal-Bench style task from scratch."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()