| import logging |
| import time |
| import torch |
| from typing import Dict, Optional, Tuple |
|
|
| from .model import BitTransformerLM |
|
|
|
|
class SafetyGate:
    """Gate on exponentially smoothed telemetry once a burn-in has elapsed.

    Two running averages are maintained, one for each score passed to
    :meth:`should_trigger`. The first observation seeds both averages; every
    later observation is blended in as ``decay * old + (1 - decay) * new``.

    Parameters
    ----------
    c_floor, s_floor:
        Thresholds for the smoothed ``c`` and ``s`` scores. The gate fires
        when either smoothed score is less than or equal to its floor.
    decay:
        Weight given to the previous average on each update.
    burn_in:
        Number of initial calls during which the gate never fires.
    """

    def __init__(
        self,
        *,
        c_floor: float = 0.3,
        s_floor: float = 0.5,
        decay: float = 0.9,
        burn_in: int = 10,
    ) -> None:
        # Running averages stay ``None`` until the first observation arrives.
        self._c_ema: Optional[float] = None
        self._s_ema: Optional[float] = None
        # Number of observations seen so far; compared against ``burn_in``.
        self.step = 0
        self.burn_in = burn_in
        self.decay = decay
        self.s_floor = s_floor
        self.c_floor = c_floor

    def should_trigger(self, c_val: float, s_val: float) -> bool:
        """Fold one observation into the averages and report the gate state.

        Parameters
        ----------
        c_val, s_val:
            Latest raw scores to blend into the running averages.

        Returns
        -------
        bool
            ``False`` for the first ``burn_in`` calls. Afterwards ``True``
            when either smoothed score is at or below its floor.
        """
        self.step += 1

        if self._c_ema is not None:
            keep = self.decay
            take = 1 - self.decay
            self._c_ema = keep * self._c_ema + take * c_val
            self._s_ema = keep * self._s_ema + take * s_val
        else:
            # Very first observation: seed the averages directly.
            self._c_ema, self._s_ema = c_val, s_val

        # The averages are updated during burn-in, but the floors are not
        # enforced until enough observations have been accumulated.
        still_warming_up = self.step <= self.burn_in
        if still_warming_up:
            return False
        return self._c_ema <= self.c_floor or self._s_ema <= self.s_floor
|
|
|
|
def hil_safe_inference(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    strict: bool = True,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run inference with telemetry gating.

    The model is switched to evaluation mode (and left there), called once
    without gradient tracking, and the mean of two telemetry entries is
    compared against the safety floors.

    Parameters
    ----------
    model:
        Model to run inference with. It is called as
        ``model(bit_seq, causal=causal)`` and must return ``(logits,
        telemetry)`` where ``telemetry`` contains the keys
        ``"lz_complexity_logits"`` and ``"symbiosis_score"``.
    bit_seq:
        Input bit sequences.
    c_floor, s_floor:
        Minimum LZ complexity and symbiosis score required for safe output.
        Only used when ``gate`` is ``None``; a supplied gate enforces its own
        floors.
    causal:
        Whether to run the model in causal (autoregressive) mode. When ``False``
        the model performs full-context Diffusion LM inference.
    strict:
        If ``False`` the function returns model outputs even when the floors are
        not met instead of raising ``RuntimeError``.
    gate:
        Optional :class:`SafetyGate` that applies EMA smoothing and burn-in
        before enforcing the floors.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        ``logits.argmax(-1)`` and the telemetry dictionary returned by the
        model.

    Raises
    ------
    RuntimeError
        If ``strict`` is true and the safety check triggers. The message
        reports the clamped per-call scores, not the gate's smoothed values.
    """
    import math  # local import keeps this block self-contained

    def _to_unit_interval(raw: float) -> float:
        # ``min``/``max`` silently drop NaN operands, so the plain clamp
        # ``max(0.0, min(1.0, nan))`` evaluates to 1.0 and would make broken
        # telemetry look perfectly safe. Map NaN to 0.0 so the gate fails
        # closed and a shared SafetyGate's EMA is never poisoned with NaN.
        if math.isnan(raw):
            return 0.0
        return max(0.0, min(1.0, raw))

    model.eval()
    with torch.no_grad():
        logits, telemetry = model(bit_seq, causal=causal)

    # Reduce each telemetry tensor to one scalar in [0, 1].
    c_val = _to_unit_interval(float(telemetry["lz_complexity_logits"].mean().item()))
    s_val = _to_unit_interval(float(telemetry["symbiosis_score"].mean().item()))

    if gate is not None:
        triggered = gate.should_trigger(c_val, s_val)
    else:
        triggered = c_val <= c_floor or s_val <= s_floor

    if strict and triggered:
        raise RuntimeError(
            f"Safety gate triggered: C={c_val:.3f}, S={s_val:.3f}"
        )
    return logits.argmax(-1), telemetry
|
|
|
|
def demo_hil_safety() -> None:
    """Run a tiny model on random bits and print the gating outcome.

    Prints the sampled bits when :func:`hil_safe_inference` returns, or the
    error message when it raises ``RuntimeError``.
    """
    # Draw the input before building the model so the random stream is
    # consumed in a fixed order.
    random_bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    tiny_model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
    )
    try:
        result = hil_safe_inference(tiny_model, random_bits, c_floor=0.0, s_floor=0.0)
        sampled = result[0]
        as_list = sampled.squeeze(0).tolist()
        print("Safe output bits:", as_list)
    except RuntimeError as gate_error:
        print("Gate triggered:", gate_error)
|
|
|
|
def safe_sample_with_retry(
    model: BitTransformerLM,
    bit_seq: torch.Tensor,
    c_floor: float = 0.3,
    s_floor: float = 0.5,
    *,
    causal: bool = True,
    max_retries: int = 3,
    backoff: float = 0.1,
    gate: Optional[SafetyGate] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Run :func:`hil_safe_inference` with automatic retries.

    After a failed safety check the helper switches to non-causal mode
    (``causal=False``) and replaces the input with freshly drawn random bits
    of the same shape, dtype and device. An exponential backoff is applied
    between attempts and a warning is logged for each failure.

    Parameters
    ----------
    model:
        Model forwarded to :func:`hil_safe_inference`.
    bit_seq:
        Input bits for the first attempt; later attempts use random bits.
    c_floor, s_floor:
        Safety floors forwarded to :func:`hil_safe_inference`.
    causal:
        Mode used for the first attempt. Every retry runs with ``False``.
    max_retries:
        Total number of attempts, including the first. Must be at least 1.
    backoff:
        Base delay in seconds; attempt ``k`` (0-based) is followed by a sleep
        of ``backoff * 2 ** k`` before the next attempt.
    gate:
        Optional :class:`SafetyGate` instance shared across retries to apply
        EMA smoothing and burn-in.

    Returns
    -------
    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        The sampled bits and associated telemetry.

    Raises
    ------
    ValueError
        If ``max_retries`` is less than 1.
    RuntimeError
        Re-raised from the final attempt when every attempt failed. Any
        ``RuntimeError`` raised by :func:`hil_safe_inference` is treated as a
        failed attempt, not only safety-gate failures.
    """
    # Without this guard a non-positive ``max_retries`` skips the loop and the
    # function silently returns ``None`` instead of the documented tuple.
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    for attempt in range(max_retries):
        try:
            return hil_safe_inference(
                model,
                bit_seq,
                c_floor,
                s_floor,
                causal=causal,
                strict=True,
                gate=gate,
            )
        except RuntimeError as exc:
            logging.warning("Safety gate failed (attempt %d/%d): %s", attempt + 1, max_retries, exc)
            if attempt >= max_retries - 1:
                # Out of attempts: surface the last failure to the caller.
                raise
            time.sleep(backoff * (2 ** attempt))
            # Retry in non-causal mode on a fresh random input.
            causal = False
            bit_seq = torch.randint(0, 2, bit_seq.shape, dtype=bit_seq.dtype, device=bit_seq.device)
|
|
|
|