File size: 7,382 Bytes
658c9d5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 | """
hardening.py β Defensive utilities for production-grade execution.
Applied across the critical path to prevent crashes from:
- None propagation
- LLM timeouts
- Malformed parser output
- Environment exceptions
- Type mismatches at boundaries
Usage:
from purpose_agent.hardening import safe_params, llm_call_with_timeout, safe_action
All functions are pure β no side effects, no state.
"""
from __future__ import annotations
import logging
import signal
import threading
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout
from typing import Any, Callable, TypeVar
logger = logging.getLogger("purpose_agent.hardening")
T = TypeVar("T")
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Null Safety
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def safe_params(params: Any) -> dict[str, Any]:
"""
Normalize action params to a guaranteed dict.
Handles: None, string, list, or any non-dict garbage from parsers.
"""
if isinstance(params, dict):
return params
if params is None:
return {}
if isinstance(params, str):
# Parser sometimes returns the raw string
return {"_raw": params}
return {}
def safe_string(value: Any, default: str = "", max_len: int = 10000) -> str:
"""Guarantee a string value. Never returns None."""
if value is None:
return default
s = str(value)
return s[:max_len] if len(s) > max_len else s
def safe_float(value: Any, default: float = 0.0, min_val: float = 0.0, max_val: float = 10.0) -> float:
"""Guarantee a bounded float. Never raises, never returns None."""
try:
f = float(str(value).rstrip('.').rstrip(','))
return max(min_val, min(max_val, f))
except (ValueError, TypeError):
return default
def safe_dict_get(d: Any, key: str, default: Any = "") -> Any:
"""Safe get from potentially-None dict."""
if not isinstance(d, dict):
return default
val = d.get(key, default)
return val if val is not None else default
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Timeout Wrapper
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def with_timeout(fn: Callable[..., T], timeout_s: float = 30.0, default: T = None, label: str = "") -> Callable[..., T]:
"""
Wrap a function with a timeout. Returns default if timeout exceeded.
Uses ThreadPoolExecutor (works on all platforms, no signals needed).
Usage:
safe_generate = with_timeout(llm.generate, timeout_s=30.0, default="", label="llm.generate")
result = safe_generate(messages, temperature=0.7)
"""
def wrapper(*args, **kwargs) -> T:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(fn, *args, **kwargs)
try:
return future.result(timeout=timeout_s)
except FuturesTimeout:
logger.error(f"TIMEOUT ({timeout_s}s): {label or fn.__name__}")
return default
except Exception as e:
logger.error(f"ERROR in {label or fn.__name__}: {type(e).__name__}: {e}")
return default
return wrapper
def llm_call_with_timeout(
llm_fn: Callable,
args: tuple = (),
kwargs: dict | None = None,
timeout_s: float = 60.0,
default: str = "",
label: str = "llm_call",
) -> str:
"""
Execute a single LLM call with timeout and error recovery.
Returns default string on any failure (timeout, network error, parse error).
NEVER raises to the caller.
"""
kwargs = kwargs or {}
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(llm_fn, *args, **kwargs)
try:
result = future.result(timeout=timeout_s)
if result is None:
return default
return str(result)
except FuturesTimeout:
logger.error(f"LLM TIMEOUT ({timeout_s}s): {label}")
return default
except Exception as e:
logger.error(f"LLM ERROR ({label}): {type(e).__name__}: {e}")
return default
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Graceful Degradation
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def graceful(fn: Callable[..., T], default: T, label: str = "") -> Callable[..., T]:
"""
Decorator: function never raises. Returns default on any exception.
Usage:
@graceful(default={}, label="parse_response")
def parse_response(text): ...
"""
def wrapper(*args, **kwargs) -> T:
try:
result = fn(*args, **kwargs)
return result if result is not None else default
except Exception as e:
logger.warning(f"Graceful degradation ({label or fn.__name__}): {type(e).__name__}: {e}")
return default
wrapper.__name__ = fn.__name__
wrapper.__doc__ = fn.__doc__
return wrapper
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Input Validation
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
class ValidationError(ValueError):
"""Raised when framework input validation fails. Always has actionable message."""
pass
def validate_purpose(purpose: str) -> str:
"""Validate and normalize a purpose string."""
if not purpose or not isinstance(purpose, str):
raise ValidationError(
"purpose must be a non-empty string. "
"Example: pa.purpose('Help me write Python code')"
)
purpose = purpose.strip()
if len(purpose) < 3:
raise ValidationError(
f"purpose too short ({len(purpose)} chars). "
"Provide a meaningful description of what you want the agent to do."
)
if len(purpose) > 5000:
purpose = purpose[:5000]
logger.warning("Purpose truncated to 5000 chars")
return purpose
def validate_model_spec(spec: str) -> str:
"""Validate a model spec string."""
if not spec or not isinstance(spec, str):
raise ValidationError(
"model must be a string like 'ollama:qwen3:1.7b' or 'openrouter:meta-llama/llama-3.3-70b-instruct'. "
"See docs for supported providers."
)
return spec.strip()
|