# (Hugging Face Spaces page header removed — scrape artifact, not part of the source.)
| """ | |
| PrefixGuard Demo - Agent Failure Detection from Traces | |
| Based on: "PrefixGuard: From LLM-Agent Traces to Online Failure-Warning Monitors" | |
| Paper: https://huggingface.co/papers/2605.06455 | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| from typing import List, Tuple, Dict | |
| import json | |
# Lazy-loaded model state
_risk_model = None
_trace_encoder = None


def get_risk_model():
    """Return the risk-scoring model, building it on first use.

    This is a simulated stand-in: a production system would load a
    trained PrefixGuard checkpoint rather than keyword heuristics.
    """
    global _risk_model
    if _risk_model is not None:
        return _risk_model
    _risk_model = {
        # Position weights — later steps contribute more to the score.
        'step_weights': np.array([0.1, 0.15, 0.2, 0.25, 0.3]),
        'failure_keywords': ['error', 'fail', 'timeout', 'exception', 'invalid', 'denied', 'unable'],
        'success_keywords': ['success', 'completed', 'done', 'result', 'output'],
    }
    return _risk_model
def encode_trace_steps(steps: List[str]) -> np.ndarray:
    """Encode agent trace steps into a (len(steps), 5) feature matrix.

    Per-step features: failure-keyword hit, success-keyword hit,
    length normalized by 500, tool-call marker, observation marker.
    (The paper uses a learned event abstraction; this is a keyword proxy.)
    """
    model = get_risk_model()
    tool_markers = ['tool', 'function', 'call', 'api']
    obs_markers = ['observation', 'result', 'returned']

    def featurize(step: str):
        text = step.lower()
        return [
            any(kw in text for kw in model['failure_keywords']),
            any(kw in text for kw in model['success_keywords']),
            len(step) / 500,
            any(marker in text for marker in tool_markers),
            any(marker in text for marker in obs_markers),
        ]

    return np.array([featurize(step) for step in steps])
def compute_prefix_risk(steps: List[str]) -> Tuple[float, List[float]]:
    """Score failure risk from a partial trace prefix.

    Returns (final_risk, per_step_risks), each rounded to 3 decimals.
    An empty prefix is maximally uncertain and scores 0.5.
    """
    if not steps:
        return 0.5, []
    model = get_risk_model()
    weights = model['step_weights']
    per_step = []
    for idx, feat in enumerate(encode_trace_steps(steps)):
        # Simplified linear scorer (the paper learns this from terminal outcomes).
        fail_score = 0.8 * feat[0] + 0.1 * feat[2]
        success_score = 0.7 * feat[1]
        # Later positions get heavier weight; clamp index to the weight table.
        pos_weight = weights[min(idx, len(weights) - 1)]
        raw = (fail_score - 0.5 * success_score) * pos_weight
        per_step.append(max(0, min(1, 0.3 + raw)))
    # Aggregate: worst risk so far blended with the latest step (recency bias).
    if per_step:
        final_risk = 0.6 * max(per_step) + 0.4 * per_step[-1]
    else:
        final_risk = 0.5
    return round(final_risk, 3), [round(r, 3) for r in per_step]
def analyze_trace(trace_text: str) -> Dict:
    """Run prefix-based failure prediction over a full agent trace.

    Each non-empty line of *trace_text* is one step. Risk is recomputed
    at every prefix length to mimic online monitoring, and the first
    step whose risk exceeds 0.7 (if any) is flagged as an early warning.
    """
    steps = [line.strip() for line in trace_text.split('\n') if line.strip()]
    if len(steps) < 2:
        return {"error": "Please provide at least 2 trace steps (one per line)"}

    # Re-score the trace at every prefix length.
    prefix_results = []
    for end in range(1, len(steps) + 1):
        prefix = steps[:end]
        risk, _per_step = compute_prefix_risk(prefix)
        if risk > 0.7:
            alert = "⚠️ HIGH RISK"
        elif risk > 0.5:
            alert = "⚡ MEDIUM"
        else:
            alert = "✅ LOW"
        latest = prefix[-1]
        prefix_results.append({
            "step": end,
            "risk_score": risk,
            "alert": alert,
            "content_preview": latest[:80] + "..." if len(latest) > 80 else latest,
        })

    final_risk = prefix_results[-1]["risk_score"]
    # First prefix whose risk crosses the high-risk threshold, if any.
    early_warning_step = next(
        (res["step"] for res in prefix_results if res["risk_score"] > 0.7),
        None,
    )
    return {
        "total_steps": len(steps),
        "final_risk": final_risk,
        "predicted_outcome": "FAILURE" if final_risk > 0.6 else "SUCCESS",
        "early_warning_at_step": early_warning_step,
        "prefix_analysis": prefix_results,
    }
def demo_interface():
    """Build and return the (unlaunched) Gradio Blocks app for the demo.

    Wires a trace textbox through process_trace() to a summary panel,
    a per-step markdown table, and a risk-progression readout.
    """
    def process_trace(trace_text):
        # Map the analysis dict onto (summary md, table md, progression text).
        result = analyze_trace(trace_text)
        if "error" in result:
            return result["error"], "", ""
        # Avoid rendering "Step None" when no early warning fired.
        warning = result['early_warning_at_step']
        warning_text = f"Step {warning}" if warning is not None else "none triggered"
        # Build summary
        summary = f"""## Analysis Results
**Total Steps:** {result['total_steps']}
**Final Risk Score:** {result['final_risk']:.3f}
**Predicted Outcome:** {result['predicted_outcome']}
**Early Warning:** {warning_text}
### Key Insight
This demonstrates how PrefixGuard predicts failures from partial traces,
enabling intervention before task completion."""
        # Build step-by-step table
        table = "| Step | Risk | Alert | Preview |\n|------|------|-------|---------|\n"
        for r in result['prefix_analysis']:
            table += f"| {r['step']} | {r['risk_score']:.3f} | {r['alert']} | {r['content_preview']} |\n"
        # Risk progression
        risks = [r['risk_score'] for r in result['prefix_analysis']]
        risk_chart = "Risk progression: " + " → ".join([f"{r:.2f}" for r in risks])
        return summary, table, risk_chart

    with gr.Blocks(title="PrefixGuard Demo") as demo:
        gr.Markdown("""# 🛡️ PrefixGuard Demo
**Agent Failure Detection from Execution Traces**
Based on: *"PrefixGuard: From LLM-Agent Traces to Online Failure-Warning Monitors"* (Huang et al., 2026)
Enter agent execution steps (one per line) to see how prefix-based monitoring predicts failures.""")
        with gr.Row():
            with gr.Column(scale=2):
                trace_input = gr.Textbox(
                    label="Agent Trace Steps (one per line)",
                    placeholder="Step 1: Calling search tool...\nStep 2: Tool returned error...\nStep 3: Retrying with...",
                    lines=10
                )
                analyze_btn = gr.Button("Analyze Trace", variant="primary")
                # Example traces — must contain REAL newlines so analyze_trace's
                # split('\n') sees multiple steps (previously "\\n" literals,
                # which made every example fail the 2-step minimum).
                gr.Examples(
                    examples=[
                        ["Tool: search_web\nObservation: 5 results found\nTool: click_result\nObservation: Page loaded\nTool: extract_data\nObservation: Success: extracted 3 records"],
                        ["Tool: api_call\nObservation: Error 500 internal server error\nTool: retry\nObservation: Error timeout\nTool: fallback\nObservation: Unable to complete"],
                        ["Step 1: Initializing agent\nStep 2: Planning task execution\nStep 3: Tool call failed with exception\nStep 4: Error propagation detected"]
                    ],
                    inputs=[trace_input],
                    label="Example Traces"
                )
            with gr.Column(scale=3):
                summary_out = gr.Markdown(label="Summary")
                table_out = gr.Markdown(label="Step-by-Step Analysis")
                chart_out = gr.Textbox(label="Risk Progression", interactive=False)
        analyze_btn.click(
            fn=process_trace,
            inputs=[trace_input],
            outputs=[summary_out, table_out, chart_out]
        )
        gr.Markdown("""---
**Note:** This is a simplified demonstration. The full PrefixGuard paper achieves 0.900 AUPRC on WebArena
using learned event abstractions and finite-state monitors trained on terminal outcomes.""")
    return demo
| if __name__ == "__main__": | |
| demo = demo_interface() | |
| demo.launch() | |