#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""Granite Switch 4.1 3B Playground — Hugging Face Space.

Each adapter has a specific prompt protocol. This app provides structured input
forms per adapter so the control tokens AND prompt formats are correct.
"""

import json
import os
import re
import time
import urllib.error
import urllib.parse
import urllib.request

import spaces
import torch
import granite_switch.hf  # noqa: F401 — registers HF backend
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "ibm-granite/granite-switch-4.1-3b-preview"

# Models the external broker is allowed to serve (short name -> HF repo id).
MODEL_OPTIONS = {
    "granite-4.0-micro": "ibm-granite/granite-4.0-micro",
    "granite-4.1-8b": "ibm-granite/granite-4.1-8b",
    "granite-4.1-30b": "ibm-granite/granite-4.1-30b",
    "qwen2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
}

USE_EXTERNAL_LLM = os.getenv("USE_EXTERNAL_LLM", "").lower() in {
    "1",
    "true",
    "yes",
    "on",
}
EXTERNAL_LLM_MODEL = os.getenv("EXTERNAL_LLM_MODEL", "granite-4.0-micro")
EXTERNAL_LLM_GPUS = int(os.getenv("EXTERNAL_LLM_GPUS", "1"))
EXTERNAL_LLM_BROKER_URL = os.getenv("EXTERNAL_LLM_BROKER_URL") or os.getenv("BROKER_URL")
EXTERNAL_LLM_TIMEOUT_SECONDS = int(os.getenv("EXTERNAL_LLM_TIMEOUT_SECONDS", "600"))
EXTERNAL_LLM_POLL_INTERVAL_SECONDS = float(
    os.getenv("EXTERNAL_LLM_POLL_INTERVAL_SECONDS", "2")
)

JUDGE_SYSTEM = (
    "As a judge agent, your role is to help assess whether the provided text "
    "meets the given judging criteria, utilizing all available information, "
    "including conversations, documents, and tools."
)

GUARDIAN_CRITERIA_BANK = {
    "harm": (
        "Texts deemed universally harmful are those that pose harm across all "
        "contexts and audiences."
    ),
    "social_bias": (
        "Social bias refers to prejudice or discrimination against individuals "
        "or groups based on social characteristics such as race, gender, "
        "religion, or socioeconomic status."
    ),
    "jailbreak": (
        "Jailbreaking: the deliberate evasion of an AI system's built-in "
        "safeguards or ethical boundaries, often by crafting specific prompts "
        "or scenarios intended to manipulate the AI into producing restricted "
        "or inappropriate content."
    ),
    "profanity": "The text contains profane or vulgar language.",
    "violence": "The text contains descriptions of or encouragement of violence.",
    "unethical_behavior": "The text promotes unethical behavior or actions.",
}

# Criteria text shared by the factuality-detection and factuality-correction tabs.
FACTUALITY_CRITERIA = (
    "A factually incorrect response occurs when the assistant's message "
    "contains one or more factual claims that are unsupported by, "
    "inconsistent with, or directly contradicted by the information "
    "provided in the documents or context."
)

# A line holding only `---` (optional surrounding blanks, optional CR) separates
# documents in the multi-document text boxes.
_DOCUMENT_SEPARATOR = re.compile(r"^[ \t]*---[ \t]*\r?$", re.MULTILINE)

# Whitespace that follows sentence-ending punctuation.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = None


def _get_model():
    """Return the local model, loading it onto CUDA on first use."""
    global model
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.bfloat16
        )
        model.eval()
        model.to("cuda")
    return model


def validate_query_llm_args(model_name, gpus, user_text):
    """Validate the arguments of an external ``query_llm`` job.

    Raises:
        ValueError: if the model is not in ``MODEL_OPTIONS``, the GPU count is
            outside 1..16, or the prompt is empty or longer than 10,000
            characters.
    """
    if model_name not in MODEL_OPTIONS:
        raise ValueError(f"Model is not allowed: {model_name}")
    if gpus < 1 or gpus > 16:
        raise ValueError("GPUs must be between 1 and 16")
    if user_text is None or not user_text.strip():
        raise ValueError("Prompt cannot be empty")
    if len(user_text) > 10_000:
        raise ValueError("Prompt is too long; max 10,000 characters")


def _broker_request(path, data=None, method="GET"):
    """Send one authenticated request to the broker and return the decoded JSON.

    Args:
        path: Broker-relative path, with or without a leading slash.
        data: Optional mapping sent as a URL-encoded form body.
        method: HTTP method.

    Raises:
        RuntimeError: if the broker URL or ``BROKER_TOKEN`` is not configured,
            the broker answers with an HTTP error, the connection fails, or
            the response body is not valid JSON.
    """
    if not EXTERNAL_LLM_BROKER_URL:
        raise RuntimeError(
            "USE_EXTERNAL_LLM is set, but EXTERNAL_LLM_BROKER_URL or BROKER_URL "
            "is missing."
        )
    broker_token = os.getenv("BROKER_TOKEN")
    if not broker_token:
        raise RuntimeError("USE_EXTERNAL_LLM is set, but BROKER_TOKEN is missing.")

    url = f"{EXTERNAL_LLM_BROKER_URL.rstrip('/')}/{path.lstrip('/')}"
    encoded_data = None
    headers = {"X-Broker-Token": broker_token}
    if data is not None:
        encoded_data = urllib.parse.urlencode(data).encode("utf-8")
        headers["Content-Type"] = "application/x-www-form-urlencoded"

    request = urllib.request.Request(
        url, data=encoded_data, headers=headers, method=method
    )
    try:
        with urllib.request.urlopen(request, timeout=60) as response:
            payload = response.read().decode("utf-8")
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"Broker returned HTTP {exc.code}: {detail}") from exc
    except urllib.error.URLError as exc:
        raise RuntimeError(f"Could not connect to broker: {exc.reason}") from exc

    # Report a malformed body the same way as every other broker failure
    # instead of leaking a bare JSONDecodeError to the UI.
    try:
        return json.loads(payload)
    except json.JSONDecodeError as exc:
        raise RuntimeError(
            f"Broker returned a non-JSON response: {payload[:200]!r}"
        ) from exc


def query_llm(user_text, max_new_tokens=128):
    """Submit a query_llm job to the broker and wait for the worker result.

    ``max_new_tokens`` is accepted so the signature matches ``_generate_local``,
    but it is not part of the form this function posts to the broker.

    Raises:
        ValueError: if the configured model/GPU count or the prompt is invalid.
        RuntimeError: if the broker request fails or the job reports ``failed``.
        TimeoutError: if the job is not finished within
            ``EXTERNAL_LLM_TIMEOUT_SECONDS``.
    """
    validate_query_llm_args(EXTERNAL_LLM_MODEL, EXTERNAL_LLM_GPUS, user_text)
    job = _broker_request(
        "/api/jobs/query-llm",
        data={
            "model": EXTERNAL_LLM_MODEL,
            "gpus": str(EXTERNAL_LLM_GPUS),
            "user_text": user_text,
        },
        method="POST",
    )
    job_id = job["id"]

    # Poll on the monotonic clock so wall-clock adjustments cannot stretch or
    # shorten the timeout.
    deadline = time.monotonic() + EXTERNAL_LLM_TIMEOUT_SECONDS
    while time.monotonic() < deadline:
        job = _broker_request(f"/api/jobs/{job_id}")
        status = job.get("status")
        if status == "done":
            return (job.get("result") or "").strip()
        if status == "failed":
            raise RuntimeError(job.get("result") or f"query_llm job {job_id} failed")
        time.sleep(EXTERNAL_LLM_POLL_INTERVAL_SECONDS)

    raise TimeoutError(
        f"Timed out waiting for query_llm job {job_id} after "
        f"{EXTERNAL_LLM_TIMEOUT_SECONDS} seconds"
    )


def _render_prompt(messages, adapter=None, documents=None):
    """Render ``messages`` to prompt text with the tokenizer's chat template.

    ``adapter`` is passed to the template as ``adapter_name`` and ``documents``
    as ``documents``; each is omitted when falsy.
    """
    kwargs = {}
    if adapter:
        kwargs["adapter_name"] = adapter
    if documents:
        kwargs["documents"] = documents
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False, **kwargs
    )


@spaces.GPU
def _generate_local(prompt, max_new_tokens=128):
    """Greedily decode a completion for ``prompt`` with the local model."""
    m = _get_model()
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        output_ids = m.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
    # Drop the echoed prompt tokens; decode only what was newly generated.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    ).strip()


def _generate(messages, adapter=None, documents=None, max_new_tokens=128):
    """Core generation: render chat template, then use local or external LLM."""
    prompt = _render_prompt(messages, adapter=adapter, documents=documents)
    # UI sliders may deliver the token budget as a float; generation needs an int.
    max_new_tokens = int(max_new_tokens)
    if USE_EXTERNAL_LLM:
        return query_llm(prompt, max_new_tokens=max_new_tokens)
    return _generate_local(prompt, max_new_tokens=max_new_tokens)


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def _split_documents(documents):
    """Split a text box value into stripped, non-empty document strings.

    Documents are separated by a line containing only ``---``.
    """
    if not documents:
        return []
    chunks = (chunk.strip() for chunk in _DOCUMENT_SEPARATOR.split(documents))
    return [chunk for chunk in chunks if chunk]


def _documents_payload(documents):
    """Return the ``[{"text": ...}]`` list handed to the chat template."""
    return [{"text": text} for text in _split_documents(documents)]


def _split_sentences(text):
    """Split ``text`` after ``.``, ``!`` or ``?`` followed by whitespace."""
    return [part for part in _SENTENCE_BOUNDARY.split(text.strip()) if part]


def _parse_json_object(raw):
    """Return ``raw`` parsed as a JSON object, or ``None`` if it is not one.

    Valid JSON that is not an object (a bare string, number or array) is also
    rejected, so callers can safely use ``.get`` on the result.
    """
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def _judge_block(criteria, schema):
    """Build the judge instruction turn shared by the guardian-style adapters."""
    return (
        f"{JUDGE_SYSTEM}\n\n"
        f"### Criteria: {criteria}\n\n"
        f"### Scoring Schema: {schema}"
    )


def _normalize_history(history):
    """Convert chat history to a list of ``{"role", "content"}`` messages.

    Message dicts are passed through unchanged; legacy ``[user, assistant]``
    pairs are expanded into two messages, skipping a missing side.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append(turn)
            continue
        user_text, assistant_text = turn
        if user_text is not None:
            messages.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            messages.append({"role": "assistant", "content": assistant_text})
    return messages


# ---------------------------------------------------------------------------
# Tab: Free Chat (base model, no adapter)
# ---------------------------------------------------------------------------


def chat_base(message, history, documents, max_tokens):
    """Answer ``message`` with the base model, optionally grounded on one document."""
    messages = _normalize_history(history) + [{"role": "user", "content": message}]
    docs = [{"text": documents.strip()}] if documents and documents.strip() else None
    return _generate(messages, adapter=None, documents=docs, max_new_tokens=max_tokens)


# ---------------------------------------------------------------------------
# Tab: Guardian (harm, bias, jailbreak detection)
# ---------------------------------------------------------------------------


def run_guardian(text_to_evaluate, criteria_choice, custom_criteria, max_tokens):
    """Judge ``text_to_evaluate`` against a predefined or custom criterion.

    Returns Markdown with the parsed score (when the adapter answered with a
    JSON object) and the raw adapter output.
    """
    if criteria_choice == "Custom":
        if not custom_criteria or not custom_criteria.strip():
            # An empty criterion gives the judge nothing to evaluate against.
            return (
                "**Error:** Please enter custom criteria, or pick one of the "
                "predefined criteria."
            )
        criteria = custom_criteria
    else:
        criteria = GUARDIAN_CRITERIA_BANK[criteria_choice]

    schema = (
        "If the last user's text meets the criteria, return 'yes'; "
        "otherwise, return 'no'."
    )
    messages = [
        {"role": "user", "content": text_to_evaluate},
        {"role": "user", "content": _judge_block(criteria, schema)},
    ]
    raw = _generate(messages, adapter="guardian-core", max_new_tokens=max_tokens)

    result = _parse_json_object(raw)
    if result is None:
        return f"**Raw output:** `{raw}`"
    score = result.get("score", raw)
    return f"**Score:** `{score}`\n\n**Raw output:** `{raw}`"


# ---------------------------------------------------------------------------
# Tab: Query Rewrite
# ---------------------------------------------------------------------------


def run_query_rewrite(query, max_tokens):
    """Rewrite ``query`` with the query_rewrite adapter and return Markdown."""
    messages = [{"role": "user", "content": query}]
    raw = _generate(messages, adapter="query_rewrite", max_new_tokens=max_tokens)
    return f"**Rewritten query:** {raw}"


# ---------------------------------------------------------------------------
# Tab: Answerability
# ---------------------------------------------------------------------------


def run_answerability(question, documents, max_tokens):
    """Ask the answerability adapter about ``question`` given ``documents``."""
    docs = _documents_payload(documents)
    messages = [{"role": "user", "content": question}]
    raw = _generate(
        messages, adapter="answerability", documents=docs, max_new_tokens=max_tokens
    )
    return f"**Result:** {raw}"


# ---------------------------------------------------------------------------
# Tab: Citations
# ---------------------------------------------------------------------------


def run_citations(question, answer, documents, max_tokens):
    """Run the citations adapter on a question/answer pair and its documents."""
    docs = _documents_payload(documents)
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]
    raw = _generate(
        messages, adapter="citations", documents=docs, max_new_tokens=max_tokens
    )
    return f"**Citations:** {raw}"


# ---------------------------------------------------------------------------
# Tab: Hallucination Detection
# ---------------------------------------------------------------------------


def run_hallucination_detection(question, answer, documents, max_tokens):
    """Run the hallucination_detection adapter on an answer and its documents."""
    docs = _documents_payload(documents)
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]
    raw = _generate(
        messages,
        adapter="hallucination_detection",
        documents=docs,
        max_new_tokens=max_tokens,
    )
    return f"**Result:** {raw}"


# ---------------------------------------------------------------------------
# Tab: Uncertainty
# ---------------------------------------------------------------------------


def run_uncertainty(conversation_text, max_tokens):
    """Report the uncertainty adapter's certainty digit as Markdown.

    The digit is mapped to a probability with ``0.1 * digit + 0.05``. When the
    output is not a JSON object with an integer-like ``score``, only the raw
    output is shown.
    """
    messages = [
        {"role": "user", "content": conversation_text},
        # NOTE(review): this empty trailing user turn looks like the slot for
        # the adapter's trigger text — confirm against the adapter's prompt spec.
        {"role": "user", "content": ""},
    ]
    raw = _generate(messages, adapter="uncertainty", max_new_tokens=max_tokens)

    result = _parse_json_object(raw)
    if result is None:
        return f"**Raw output:** `{raw}`"
    try:
        digit = int(result.get("score", 0))
    except (TypeError, ValueError):
        # e.g. {"score": null} or {"score": "high"}
        return f"**Raw output:** `{raw}`"

    prob = 0.1 * digit + 0.05
    return (
        f"**Certainty digit:** `{digit}`\n\n"
        f"**Calibrated probability:** ~{prob*100:.0f}%\n\n"
        f"**Raw output:** `{raw}`"
    )


# ---------------------------------------------------------------------------
# Tab: Requirement Check
# ---------------------------------------------------------------------------


def run_requirement_check(user_question, assistant_response, requirements, max_tokens):
    """Check whether ``assistant_response`` satisfies ``requirements``."""
    evaluation_prompt = (
        "Please verify if the assistant's generation satisfies the user's "
        "requirements or not and reply with a binary label accordingly. "
        'Respond with a json {"score": "yes"} if the constraints are satisfied '
        'or respond with {"score": "no"} if the constraints are not satisfied.'
    )
    # NOTE(review): the leading space may be the remnant of a prefix tag that
    # was lost — confirm against the requirement-check prompt spec.
    req_turn = f" {requirements}\n{evaluation_prompt}"
    messages = [
        {"role": "user", "content": user_question},
        {"role": "assistant", "content": assistant_response},
        {"role": "user", "content": req_turn},
    ]
    raw = _generate(messages, adapter="requirement-check", max_new_tokens=max_tokens)

    result = _parse_json_object(raw)
    if result is None:
        return f"**Raw output:** `{raw}`"
    score = result.get("score", raw)
    label = "Satisfied" if score == "yes" else "Not satisfied"
    return f"**{label}** (`{score}`)\n\n**Raw output:** `{raw}`"


# ---------------------------------------------------------------------------
# Tab: Factuality Detection
# ---------------------------------------------------------------------------


def run_factuality_detection(assistant_response, documents, max_tokens):
    """Flag factual errors in ``assistant_response`` relative to ``documents``."""
    docs = _documents_payload(documents)
    schema = (
        "If the last assistant's text meets the criteria, return 'yes'; "
        "otherwise, return 'no'."
    )
    messages = [
        {"role": "assistant", "content": assistant_response},
        {"role": "user", "content": _judge_block(FACTUALITY_CRITERIA, schema)},
    ]
    raw = _generate(
        messages,
        adapter="factuality-detection",
        documents=docs,
        max_new_tokens=max_tokens,
    )

    result = _parse_json_object(raw)
    if result is None:
        return f"**Raw output:** `{raw}`"
    score = result.get("score", raw)
    label = "Factual errors found" if score == "yes" else "No errors detected"
    return f"**{label}** (`{score}`)\n\n**Raw output:** `{raw}`"


# ---------------------------------------------------------------------------
# Tab: Factuality Correction
# ---------------------------------------------------------------------------


def run_factuality_correction(assistant_response, documents, max_tokens):
    """Return a corrected version of ``assistant_response``, if one is needed."""
    docs = _documents_payload(documents)
    schema = (
        "If the last assistant's text meets the criteria, return a corrected "
        "version of the assistant's message based on the given context; "
        "otherwise, return 'none'."
    )
    messages = [
        {"role": "assistant", "content": assistant_response},
        {"role": "user", "content": _judge_block(FACTUALITY_CRITERIA, schema)},
    ]
    raw = _generate(
        messages,
        adapter="factuality-correction",
        documents=docs,
        max_new_tokens=max_tokens,
    )

    result = _parse_json_object(raw)
    if result is None:
        return f"**Raw output:** `{raw}`"
    correction = result.get("correction", raw)
    if correction == "none":
        return "**No correction needed.**\n\n" + f"**Raw output:** `{raw}`"
    return f"**Corrected text:**\n\n{correction}\n\n**Raw output:** `{raw}`"


# ---------------------------------------------------------------------------
# Tab: Policy Guardrails
# ---------------------------------------------------------------------------


def run_policy_guardrails(scenario, policy, max_tokens):
    """Judge whether ``scenario`` complies with ``policy``."""
    policy_system = (
        "You are a compliance agent trying to help determine whether a "
        "scenario is compliant with a given policy."
    )
    scoring_schema = (
        "Does the scenario described in the previous message comply with the "
        "policy? Compliance is determined based on the following guidelines:\n"
        '- "Yes" if the scenario complies with certainty\n'
        '- "No" if the scenario does not comply with certainty\n'
        '- "Ambiguous" if more information is needed\n\n'
        'Your answer must be either "Yes", "No", or "Ambiguous". '
        'Return as JSON: {"label": "Yes"/"No"/"Ambiguous"}.'
    )
    # NOTE(review): the leading space may be the remnant of a prefix tag that
    # was lost — confirm against the policy-guardrails prompt spec.
    policy_block = (
        f" {policy_system}\n\n"
        f"### Criteria: Policy: {policy}\n\n"
        f"### Scoring Schema: {scoring_schema}"
    )
    messages = [
        {"role": "user", "content": scenario},
        {"role": "user", "content": policy_block},
    ]
    raw = _generate(messages, adapter="policy-guardrails", max_new_tokens=max_tokens)

    result = _parse_json_object(raw)
    if result is None:
        return f"**Raw output:** `{raw}`"
    label = result.get("label", raw)
    return f"**Compliance:** `{label}`\n\n**Raw output:** `{raw}`"


# ---------------------------------------------------------------------------
# Tab: Context Attribution
# ---------------------------------------------------------------------------


def run_context_attribution(question, response, documents, max_tokens):
    """Ask which context sentences support each sentence of ``response``.

    Context sentences are tagged ``<c0>``, ``<c1>``, ... with one counter that
    runs across all documents; response sentences are tagged ``<r0>``,
    ``<r1>``, ... The instruction turn refers to those tags.
    """
    context_index = 0
    tagged_doc_parts = []
    for doc in _split_documents(documents):
        tagged_sentences = []
        for sentence in _split_sentences(doc):
            tagged_sentences.append(f"<c{context_index}> {sentence}")
            context_index += 1
        tagged_doc_parts.append({"text": " ".join(tagged_sentences)})

    tagged_response = " ".join(
        f"<r{index}> {sentence}"
        for index, sentence in enumerate(_split_sentences(response))
    )

    instruction = (
        "You provided the last assistant response above based on context, which may "
        "include documents and/or previous conversation turns. Your response is "
        "divided into sentences, numbered in the format <r0> sentence 0 <r1> "
        "sentence 1 ... Sentences in the context are also numbered: <c0> sentence 0 "
        "<c1> sentence 1 ... For each response sentence, please list the context "
        "sentences that were most important for you to generate the response "
        "sentence. Provide your answer in JSON format, as an array of JSON objects, "
        'where each object has two members: "r" with the response sentence number '
        'as the value, and "c" with an array of context sentence numbers as the '
        "value. List the context sentences in order from most important to least "
        "important. Ensure that you include an object for each response sentence, "
        "even if the corresponding array of context sentence numbers is empty. "
        "Answer with only the JSON and do not explain.\n"
    )
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": tagged_response},
        {"role": "user", "content": instruction},
    ]
    raw = _generate(
        messages,
        adapter="context-attribution",
        documents=tagged_doc_parts,
        max_new_tokens=max_tokens,
    )
    return f"**Attribution:**\n```json\n{raw}\n```"


# ---------------------------------------------------------------------------
# Build the Gradio UI with tabs per adapter
# ---------------------------------------------------------------------------

with gr.Blocks(title="Granite Switch 4.1 3B Playground") as demo:
    gr.Markdown(
        "# Granite Switch 4.1 3B Playground\n\n"
        "Interactive demo of [ibm-granite/granite-switch-4.1-3b-preview]"
        "(https://huggingface.co/ibm-granite/granite-switch-4.1-3b-preview) "
        "with 12 embedded adapters. Each tab provides the correct prompt "
        "format for its adapter."
    )

    with gr.Tabs():
        # --- Free Chat ---
        with gr.Tab("Chat (Base Model)"):
            gr.Markdown(
                "Standard chat with the base model. Optionally provide documents for grounded responses."
            )
            chat_interface = gr.ChatInterface(
                fn=chat_base,
                additional_inputs=[
                    gr.Textbox(
                        label="Documents (optional)",
                        lines=4,
                        placeholder="Paste reference documents here...",
                    ),
                    gr.Slider(16, 512, value=128, step=16, label="Max new tokens"),
                ],
            )

        # --- Guardian ---
        with gr.Tab("Guardian"):
            gr.Markdown(
                "**guardian-core** — Evaluate text for harm, bias, jailbreak, etc.\n\n"
                "Returns `yes` (flagged) or `no` (safe)."
            )
            with gr.Row():
                with gr.Column():
                    guardian_text = gr.Textbox(
                        label="Text to evaluate",
                        lines=3,
                        placeholder="Enter the text you want to check for safety...",
                    )
                    guardian_criteria = gr.Dropdown(
                        choices=list(GUARDIAN_CRITERIA_BANK.keys()) + ["Custom"],
                        value="harm",
                        label="Criteria",
                    )
                    guardian_custom = gr.Textbox(
                        label="Custom criteria (if 'Custom' selected above)",
                        lines=2,
                        visible=True,
                    )
                    guardian_tokens = gr.Slider(16, 64, value=20, step=4, label="Max tokens")
                    guardian_btn = gr.Button("Evaluate", variant="primary")
                with gr.Column():
                    guardian_output = gr.Markdown(label="Result")
            guardian_btn.click(
                run_guardian,
                inputs=[guardian_text, guardian_criteria, guardian_custom, guardian_tokens],
                outputs=guardian_output,
            )

        # --- Query Rewrite ---
        with gr.Tab("Query Rewrite"):
            gr.Markdown(
                "**query_rewrite** — Rewrites messy or verbose queries into clean, search-friendly form."
            )
            with gr.Row():
                with gr.Column():
                    qr_query = gr.Textbox(
                        label="Original query",
                        lines=2,
                        placeholder="e.g., what is...mmmm the main city (capital you call it?) of France?",
                    )
                    qr_tokens = gr.Slider(16, 256, value=64, step=16, label="Max tokens")
                    qr_btn = gr.Button("Rewrite", variant="primary")
                with gr.Column():
                    qr_output = gr.Markdown(label="Result")
            qr_btn.click(run_query_rewrite, inputs=[qr_query, qr_tokens], outputs=qr_output)

        # --- Answerability ---
        with gr.Tab("Answerability"):
            gr.Markdown(
                "**answerability** — Can the question be answered from the provided documents?\n\n"
                "Separate multiple documents with `---` on its own line."
            )
            with gr.Row():
                with gr.Column():
                    ans_question = gr.Textbox(label="Question", lines=2)
                    ans_docs = gr.Textbox(
                        label="Documents (separated by ---)",
                        lines=5,
                        placeholder="Document 1 text...\n---\nDocument 2 text...",
                    )
                    ans_tokens = gr.Slider(16, 128, value=32, step=16, label="Max tokens")
                    ans_btn = gr.Button("Check", variant="primary")
                with gr.Column():
                    ans_output = gr.Markdown(label="Result")
            ans_btn.click(
                run_answerability,
                inputs=[ans_question, ans_docs, ans_tokens],
                outputs=ans_output,
            )

        # --- Citations ---
        with gr.Tab("Citations"):
            gr.Markdown(
                "**citations** — Find which document passages support a given answer.\n\n"
                "Separate multiple documents with `---`."
            )
            with gr.Row():
                with gr.Column():
                    cit_question = gr.Textbox(label="Question", lines=2)
                    cit_answer = gr.Textbox(label="Answer to attribute", lines=3)
                    cit_docs = gr.Textbox(
                        label="Documents (separated by ---)",
                        lines=5,
                    )
                    cit_tokens = gr.Slider(16, 256, value=128, step=16, label="Max tokens")
                    cit_btn = gr.Button("Find Citations", variant="primary")
                with gr.Column():
                    cit_output = gr.Markdown(label="Result")
            cit_btn.click(
                run_citations,
                inputs=[cit_question, cit_answer, cit_docs, cit_tokens],
                outputs=cit_output,
            )

        # --- Hallucination Detection ---
        with gr.Tab("Hallucination Detection"):
            gr.Markdown(
                "**hallucination_detection** — Detect hallucinated content in a response "
                "relative to source documents.\n\nSeparate documents with `---`."
            )
            with gr.Row():
                with gr.Column():
                    hall_question = gr.Textbox(label="Question", lines=2)
                    hall_answer = gr.Textbox(label="Response to check", lines=3)
                    hall_docs = gr.Textbox(label="Source documents (separated by ---)", lines=5)
                    hall_tokens = gr.Slider(16, 256, value=64, step=16, label="Max tokens")
                    hall_btn = gr.Button("Detect", variant="primary")
                with gr.Column():
                    hall_output = gr.Markdown(label="Result")
            hall_btn.click(
                run_hallucination_detection,
                inputs=[hall_question, hall_answer, hall_docs, hall_tokens],
                outputs=hall_output,
            )

        # --- Uncertainty ---
        with gr.Tab("Uncertainty"):
            gr.Markdown(
                "**uncertainty** — Returns a calibrated confidence digit (0-9) for the "
                "last assistant response.\n\n"
                "Digit maps to probability: `0.1 * digit + 0.05` (5% to 95%)."
            )
            with gr.Row():
                with gr.Column():
                    unc_text = gr.Textbox(
                        label="Assistant response to evaluate certainty of",
                        lines=4,
                        placeholder="Paste the response you want to gauge confidence for...",
                    )
                    unc_tokens = gr.Slider(16, 32, value=20, step=4, label="Max tokens")
                    unc_btn = gr.Button("Check Certainty", variant="primary")
                with gr.Column():
                    unc_output = gr.Markdown(label="Result")
            unc_btn.click(run_uncertainty, inputs=[unc_text, unc_tokens], outputs=unc_output)

        # --- Requirement Check ---
        with gr.Tab("Requirement Check"):
            gr.Markdown(
                "**requirement-check** — Does the assistant's response satisfy "
                "stated requirements?\n\nReturns `yes` or `no`."
            )
            with gr.Row():
                with gr.Column():
                    req_question = gr.Textbox(label="User question", lines=2)
                    req_response = gr.Textbox(label="Assistant response", lines=4)
                    req_requirements = gr.Textbox(
                        label="Requirements",
                        lines=3,
                        placeholder="e.g., Must be formal tone. Under 100 words. Must cite sources.",
                    )
                    req_tokens = gr.Slider(16, 32, value=20, step=4, label="Max tokens")
                    req_btn = gr.Button("Check", variant="primary")
                with gr.Column():
                    req_output = gr.Markdown(label="Result")
            req_btn.click(
                run_requirement_check,
                inputs=[req_question, req_response, req_requirements, req_tokens],
                outputs=req_output,
            )

        # --- Factuality Detection ---
        with gr.Tab("Factuality Detection"):
            gr.Markdown(
                "**factuality-detection** — Check if a response contains factual errors "
                "vs source documents.\n\nSeparate documents with `---`."
            )
            with gr.Row():
                with gr.Column():
                    fd_response = gr.Textbox(label="Response to check", lines=4)
                    fd_docs = gr.Textbox(label="Source documents (separated by ---)", lines=5)
                    fd_tokens = gr.Slider(16, 32, value=20, step=4, label="Max tokens")
                    fd_btn = gr.Button("Detect", variant="primary")
                with gr.Column():
                    fd_output = gr.Markdown(label="Result")
            fd_btn.click(
                run_factuality_detection,
                inputs=[fd_response, fd_docs, fd_tokens],
                outputs=fd_output,
            )

        # --- Factuality Correction ---
        with gr.Tab("Factuality Correction"):
            gr.Markdown(
                "**factuality-correction** — Correct factual errors in a response "
                "using source documents.\n\nSeparate documents with `---`."
            )
            with gr.Row():
                with gr.Column():
                    fc_response = gr.Textbox(label="Response to correct", lines=4)
                    fc_docs = gr.Textbox(label="Source documents (separated by ---)", lines=5)
                    fc_tokens = gr.Slider(16, 512, value=256, step=16, label="Max tokens")
                    fc_btn = gr.Button("Correct", variant="primary")
                with gr.Column():
                    fc_output = gr.Markdown(label="Result")
            fc_btn.click(
                run_factuality_correction,
                inputs=[fc_response, fc_docs, fc_tokens],
                outputs=fc_output,
            )

        # --- Policy Guardrails ---
        with gr.Tab("Policy Guardrails"):
            gr.Markdown(
                "**policy-guardrails** — Check if a scenario complies with a policy.\n\n"
                "Returns `Yes`, `No`, or `Ambiguous`."
            )
            with gr.Row():
                with gr.Column():
                    pol_scenario = gr.Textbox(
                        label="Scenario (text to evaluate)",
                        lines=4,
                        placeholder="The assistant response or action to judge...",
                    )
                    pol_policy = gr.Textbox(
                        label="Policy",
                        lines=3,
                        placeholder="e.g., Responses must not provide investment advice.",
                    )
                    pol_tokens = gr.Slider(16, 32, value=20, step=4, label="Max tokens")
                    pol_btn = gr.Button("Evaluate", variant="primary")
                with gr.Column():
                    pol_output = gr.Markdown(label="Result")
            pol_btn.click(
                run_policy_guardrails,
                inputs=[pol_scenario, pol_policy, pol_tokens],
                outputs=pol_output,
            )

        # --- Context Attribution ---
        with gr.Tab("Context Attribution"):
            gr.Markdown(
                "**context-attribution** — Which context sentences supported each "
                "sentence of the response?\n\nSeparate documents with `---`."
            )
            with gr.Row():
                with gr.Column():
                    ca_question = gr.Textbox(label="Question", lines=2)
                    ca_response = gr.Textbox(label="Response to attribute", lines=4)
                    ca_docs = gr.Textbox(label="Context documents (separated by ---)", lines=5)
                    ca_tokens = gr.Slider(16, 512, value=256, step=16, label="Max tokens")
                    ca_btn = gr.Button("Attribute", variant="primary")
                with gr.Column():
                    ca_output = gr.Markdown(label="Result")
            ca_btn.click(
                run_context_attribution,
                inputs=[ca_question, ca_response, ca_docs, ca_tokens],
                outputs=ca_output,
            )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)