File size: 7,650 Bytes
9f7707b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
PrefixGuard Demo - Agent Failure Detection from Traces
Based on: "PrefixGuard: From LLM-Agent Traces to Online Failure-Warning Monitors"
Paper: https://huggingface.co/papers/2605.06455
"""

import gradio as gr
import numpy as np
from typing import List, Tuple, Dict
import json

# Lazily initialized model state (built on first call to get_risk_model).
_risk_model = None
_trace_encoder = None  # placeholder; never populated in this demo

def get_risk_model():
    """Return the heuristic risk model, constructing it on first use.

    Stand-in for loading a trained PrefixGuard checkpoint: a plain dict
    holding per-position step weights and failure/success keyword lists.
    """
    global _risk_model
    if _risk_model is not None:
        return _risk_model
    _risk_model = {
        'step_weights': np.array([0.1, 0.15, 0.2, 0.25, 0.3]),  # later steps matter more
        'failure_keywords': ['error', 'fail', 'timeout', 'exception', 'invalid', 'denied', 'unable'],
        'success_keywords': ['success', 'completed', 'done', 'result', 'output'],
    }
    return _risk_model

def encode_trace_steps(steps: List[str]) -> np.ndarray:
    """Encode each agent trace step as a 5-dim feature row.

    Per-step features: failure-keyword hit, success-keyword hit,
    length normalized by 500, tool-call marker hit, observation marker
    hit. (The paper uses a learned event abstraction instead.)
    """
    model = get_risk_model()
    tool_markers = ['tool', 'function', 'call', 'api']
    obs_markers = ['observation', 'result', 'returned']

    def _row(text):
        lowered = text.lower()
        return [
            any(kw in lowered for kw in model['failure_keywords']),
            any(kw in lowered for kw in model['success_keywords']),
            len(text) / 500,
            any(m in lowered for m in tool_markers),
            any(m in lowered for m in obs_markers),
        ]

    return np.array([_row(step) for step in steps])

def compute_prefix_risk(steps: List[str]) -> Tuple[float, List[float]]:
    """Score the failure risk of a partial trace prefix.

    Returns (aggregate risk, per-step risks), each rounded to 3 decimal
    places. An empty prefix scores a neutral 0.5 with no step risks.
    """
    if not steps:
        return 0.5, []

    model = get_risk_model()
    weights = model['step_weights']
    last_weight_idx = len(weights) - 1

    per_step = []
    for idx, row in enumerate(encode_trace_steps(steps)):
        # Heuristic stand-in for the paper's learned scorer: failure
        # keywords push risk up, success keywords pull it down.
        fail_signal = row[0] * 0.8 + row[2] * 0.1
        success_signal = row[1] * 0.7
        # Position weight: later steps contribute more.
        weighted = (fail_signal - success_signal * 0.5) * weights[min(idx, last_weight_idx)]
        per_step.append(max(0, min(1, 0.3 + weighted)))

    # Blend the worst risk observed so far with the most recent one.
    blended = 0.6 * max(per_step) + 0.4 * per_step[-1]
    return round(blended, 3), [round(r, 3) for r in per_step]

def analyze_trace(trace_text: str) -> Dict:
    """Analyze a full agent trace, predicting failure from each prefix.

    Parameters
    ----------
    trace_text : str
        Raw trace text, one step per line; blank lines are ignored.

    Returns
    -------
    Dict
        Either ``{"error": ...}`` for traces shorter than 2 steps, or a
        report with ``total_steps``, ``final_risk``, ``predicted_outcome``
        ("FAILURE" when final risk > 0.6), ``early_warning_at_step``
        (first step whose risk exceeds 0.7, else None), and the
        per-prefix ``prefix_analysis`` rows.
    """
    steps = [s.strip() for s in trace_text.split('\n') if s.strip()]

    if len(steps) < 2:
        return {
            "error": "Please provide at least 2 trace steps (one per line)"
        }

    # Score every prefix, simulating an online monitor watching the
    # trace grow one step at a time.
    prefix_results = []
    for i in range(1, len(steps) + 1):
        risk, _ = compute_prefix_risk(steps[:i])  # per-step risks unused here
        latest = steps[i - 1]
        prefix_results.append({
            "step": i,
            "risk_score": risk,
            "alert": "⚠️ HIGH RISK" if risk > 0.7 else ("⚡ MEDIUM" if risk > 0.5 else "✅ LOW"),
            "content_preview": latest[:80] + "..." if len(latest) > 80 else latest,
        })

    final_risk = prefix_results[-1]["risk_score"]
    # First prefix whose risk crosses the alert threshold, if any.
    early_warning_step = next(
        (r["step"] for r in prefix_results if r["risk_score"] > 0.7), None
    )

    return {
        "total_steps": len(steps),
        "final_risk": final_risk,
        "predicted_outcome": "FAILURE" if final_risk > 0.6 else "SUCCESS",
        "early_warning_at_step": early_warning_step,
        "prefix_analysis": prefix_results,
    }

def demo_interface():
    """Build the Gradio Blocks UI for the PrefixGuard demo.

    Returns the (unlaunched) gr.Blocks app: trace input and example
    traces on the left, summary / per-step table / risk progression
    outputs on the right.
    """

    def process_trace(trace_text: str):
        # Button callback: returns (summary markdown, per-step markdown
        # table, risk-progression string) for the three output widgets.
        result = analyze_trace(trace_text)

        if "error" in result:
            # Input validation failed: show the message, blank the rest.
            return result["error"], "", ""

        # Build summary
        # NOTE(review): when no step exceeds the 0.7 threshold this
        # renders "Step None (if any)" — confirm that's acceptable.
        summary = f"""## Analysis Results

**Total Steps:** {result['total_steps']}
**Final Risk Score:** {result['final_risk']:.3f}
**Predicted Outcome:** {result['predicted_outcome']}
**Early Warning:** Step {result['early_warning_at_step']} (if any)

### Key Insight
This demonstrates how PrefixGuard predicts failures from partial traces,
enabling intervention before task completion."""

        # Build step-by-step table (markdown, rendered by gr.Markdown)
        table = "| Step | Risk | Alert | Preview |\n|------|------|-------|---------|\n"
        for r in result['prefix_analysis']:
            table += f"| {r['step']} | {r['risk_score']:.3f} | {r['alert']} | {r['content_preview']} |\n"

        # Risk progression as a plain-text arrow chain
        risks = [r['risk_score'] for r in result['prefix_analysis']]
        risk_chart = "Risk progression: " + " → ".join([f"{r:.2f}" for r in risks])

        return summary, table, risk_chart

    with gr.Blocks(title="PrefixGuard Demo") as demo:
        gr.Markdown("""# 🛡️ PrefixGuard Demo
        
**Agent Failure Detection from Execution Traces**
        
Based on: *"PrefixGuard: From LLM-Agent Traces to Online Failure-Warning Monitors"* (Huang et al., 2026)
        
Enter agent execution steps (one per line) to see how prefix-based monitoring predicts failures.""")

        with gr.Row():
            with gr.Column(scale=2):
                trace_input = gr.Textbox(
                    label="Agent Trace Steps (one per line)",
                    placeholder="Step 1: Calling search tool...\nStep 2: Tool returned error...\nStep 3: Retrying with...",
                    lines=10
                )
                analyze_btn = gr.Button("Analyze Trace", variant="primary")

                # Example traces
                # NOTE(review): these strings contain literal "\\n"
                # (backslash + n), not real newlines — verify the Textbox
                # is meant to receive multi-line input this way.
                gr.Examples(
                    examples=[
                        ["Tool: search_web\\nObservation: 5 results found\\nTool: click_result\\nObservation: Page loaded\\nTool: extract_data\\nObservation: Success: extracted 3 records"],
                        ["Tool: api_call\\nObservation: Error 500 internal server error\\nTool: retry\\nObservation: Error timeout\\nTool: fallback\\nObservation: Unable to complete"],
                        ["Step 1: Initializing agent\\nStep 2: Planning task execution\\nStep 3: Tool call failed with exception\\nStep 4: Error propagation detected"]
                    ],
                    inputs=[trace_input],
                    label="Example Traces"
                )

            with gr.Column(scale=3):
                summary_out = gr.Markdown(label="Summary")
                table_out = gr.Markdown(label="Step-by-Step Analysis")
                chart_out = gr.Textbox(label="Risk Progression", interactive=False)

        # Wire the analyze button to the callback defined above.
        analyze_btn.click(
            fn=process_trace,
            inputs=[trace_input],
            outputs=[summary_out, table_out, chart_out]
        )

        gr.Markdown("""---
**Note:** This is a simplified demonstration. The full PrefixGuard paper achieves 0.900 AUPRC on WebArena 
using learned event abstractions and finite-state monitors trained on terminal outcomes.""")

    return demo

if __name__ == "__main__":
    # Build and launch the Gradio app when run as a script.
    demo = demo_interface()
    demo.launch()