riprap-nyc / app /mellea_validator.py
seriffic's picture
deploy: sync all changes from main at 6904684
b9a10ad
"""Mellea-validated reconciliation for Riprap.
Wraps the existing Granite-via-Ollama reconciliation in IBM Research's
Mellea framework: typed output + programmatic post-conditions +
rejection sampling. Replaces post-hoc sentence-dropping with
"don't accept output until requirements pass."
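Streaming and Mellea's rejection sampling don't mix — by the time a draft
can be validated, the user has already watched any bad output appear.
reconcile_strict trades streaming for compliance; the UI shows a
"validating" skeleton instead of token-by-token render.
reconcile_strict_streaming is a hand-rolled alternative that streams each
attempt anyway and reports per-attempt outcomes so the frontend can reset
the briefing when a reroll begins.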
The four invariants ported from the parent project's mellea_probe:
1. no_invented_numbers — every number in output appears in source
2. no_placeholder_tokens — output never contains "[source]" or
raw <document> markup
3. every_claim_cited — each sentence containing a non-trivial numeric
token also contains a [doc_id] citation
4. referenced_doc_ids_exist — cited doc_ids ⊆ input doc_ids
"""
from __future__ import annotations
import logging
import os
import re
import time
from typing import Any
from mellea import start_session
from mellea.stdlib.requirements import req, simple_validate
from mellea.stdlib.sampling import RejectionSamplingStrategy
from app import llm
log = logging.getLogger("riprap.mellea")
# Default reconciler model, resolved once at import time. Same env-var
# contract as app/reconcile.py: RIPRAP_RECONCILER_MODEL wins, then
# RIPRAP_OLLAMA_MODEL, then the built-in Granite tag. A variable that is
# set but empty still wins (plain lookup, no truthiness fallback).
DEFAULT_MODEL = os.getenv(
    "RIPRAP_RECONCILER_MODEL",
    os.getenv("RIPRAP_OLLAMA_MODEL", "granite4.1:8b"),
)
# Loop budget — try up to N samples before falling back to the last
# candidate even if it didn't pass all requirements. Low ceiling so a
# pathological case can't run away with latency.
#
# Override at process start with RIPRAP_MELLEA_MAX_ATTEMPTS. We default
# to 2 on the local Ollama path (where each attempt is 30-90 s on the
# Mac) and 3 on remote/vLLM (where attempts are seconds). This caps
# worst-case demo latency without giving up the principal grounding
# guarantee — the first-attempt pass rate on the curated probes is >85%.
def _default_loop_budget() -> int:
try:
n = int(os.environ.get("RIPRAP_MELLEA_MAX_ATTEMPTS", "0"))
if n > 0:
return n
except ValueError:
pass
return 2 if os.environ.get("RIPRAP_LLM_PRIMARY", "ollama").lower() == "ollama" else 3
DEFAULT_LOOP_BUDGET = _default_loop_budget()
# Number tokens — \b enforces a word boundary so identifier codes like
# QN1206, B12 (community board), or M14 (bus route) are skipped entirely.
# Inside QN1206 there's no \b between any chars, so no submatch leaks.
_NUM_RE = re.compile(r"\b-?\d[\d,]*(?:\.\d+)?\b")
_CITE_RE = re.compile(r"\[(?P<id>[a-z][a-z0-9_]*)\]")
# Same trivial-numbers list as the post-hoc verifier — well-known service
# line numbers, single digits.
_TRIVIAL_NUMS = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100",
"311", "911", "211"}
def _strip_markdown_for_check(text: str) -> str:
"""Drop bold markers + citation tags so the numeric scan is clean."""
text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
text = re.sub(r"\[[a-z0-9_]+\]", "", text, flags=re.I)
return text
def _normalize_num(s: str) -> set[str]:
    """Return equivalent spellings of the numeric token *s*.

    The grounding check accepts a token when any of these spellings occurs
    in the source text. Forms produced:

    * *s* itself, and *s* with commas removed;
    * for decimals, the comma-free form with trailing zeros (and a dangling
      point) trimmed, e.g. ``"12.50"`` -> ``"12.5"``, ``"10.0"`` -> ``"10"``;
    * the thousands-grouped spelling of each comma-free form whose integer
      part has more than three digits and no leading zero, e.g.
      ``"1234"`` -> ``"1,234"``.

    The grouped form makes the comma normalization symmetric. Previously,
    output "1,234" matched a source "1234", but output "1234" was flagged as
    invented against a source "1,234".

    >>> sorted(_normalize_num("1234"))
    ['1,234', '1234']
    """
    no_comma = s.replace(",", "")
    forms = {s, no_comma}
    if "." in no_comma:
        forms.add(no_comma.rstrip("0").rstrip("."))
    # Snapshot the comma-free forms first; `forms` grows inside the loop.
    for plain in [f for f in forms if f and "," not in f]:
        sign = "-" if plain.startswith("-") else ""
        int_part, dot, frac = plain[len(sign):].partition(".")
        # Skip leading-zero tokens (ZIP-style codes): int() would drop the
        # zero and produce a spelling of a different token.
        if len(int_part) > 3 and int_part.isdecimal() and not int_part.startswith("0"):
            forms.add(f"{sign}{int(int_part):,}{dot}{frac}")
    return {f for f in forms if f}
def _haystack(doc_msgs: list[dict]) -> str:
    """Concatenate every message's ``content`` (missing -> ""), newline-separated."""
    contents = [msg.get("content", "") for msg in doc_msgs]
    return "\n".join(contents)
def _doc_ids(doc_msgs: list[dict]) -> set[str]:
    """Collect the ids of the document messages in *doc_msgs*.

    A document message has a ``role`` of the form ``"document <id>"``.
    Everything after that prefix is whitespace-stripped and used as the
    id. Messages with any other role, or no role, are ignored.
    """
    prefix = "document "
    roles = (msg.get("role", "") for msg in doc_msgs)
    return {role[len(prefix):].strip() for role in roles if role.startswith(prefix)}
# --- the four invariants ---------------------------------------------------
def _check_no_invented_numbers(doc_msgs: list[dict]):
    """Build the numerics_grounded validator for *doc_msgs*.

    The returned callable strips bold markers and citation tags from the
    text, scans it with ``_NUM_RE``, and passes only if every non-trivial
    token has at least one ``_normalize_num`` spelling somewhere in the
    concatenated document text. The document text is captured once, when
    the validator is built.

    Matching is substring-based: "12" is accepted if the documents only
    contain "120" or "2012".
    """
    source_text = _haystack(doc_msgs)

    def _is_grounded(token: str) -> bool:
        if token in _TRIVIAL_NUMS:
            return True
        return any(form in source_text for form in _normalize_num(token))

    def _fn(text: str):
        tokens = _NUM_RE.findall(_strip_markdown_for_check(text))
        return all(_is_grounded(tok) for tok in tokens)  # pass = no invented numbers
    return _fn
def _check_no_placeholder_tokens():
    """Build the no_placeholder_tokens validator.

    The returned callable fails on:
    - "[source]" in any letter case;
    - raw "<document" / "</document" markup;
    - the literal "[doc_id]", which is the prompt skeleton echoed verbatim
      (case-sensitive).
    """
    def _fn(text: str):
        if "[source]" in text.lower():
            return False
        literal_markers = ("<document", "</document", "[doc_id]")
        return not any(marker in text for marker in literal_markers)
    return _fn
def _check_every_claim_cited():
    """Build the citations_dense validator.

    The returned callable passes when every non-trivial numeric token
    shares its sentence with at least one [doc_id] tag. A non-trivial token
    is anything ``_NUM_RE`` finds that is not in ``_TRIVIAL_NUMS``. Only
    **bold** markers are stripped before scanning.

    A sentence runs from just after the previous terminator to just past
    the next one. A terminator is a period followed by whitespace or ")",
    or a period at the end of the scanned text. The citation can sit
    anywhere in that span, matching how a reader attributes a claim. A tag
    placed after the sentence's closing period falls into the next sentence.
    """
    # `?` and `!` are deliberately not terminators; they rarely appear in
    # these briefings.
    sentence_end = re.compile(r"\.[\s)]|\.$")

    def _sentence_span(text: str, pos: int) -> tuple[int, int]:
        # The backward scan is bounded by endpos=pos, so `\.$` also fires on
        # a period sitting immediately before the number.
        ends_before = [hit.end() for hit in sentence_end.finditer(text, 0, pos)]
        start = ends_before[-1] if ends_before else 0
        following = sentence_end.search(text, pos)
        end = len(text) if following is None else following.start() + 1
        return start, end

    def _fn(text: str):
        clean = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)

        def _cited(hit) -> bool:
            lo, hi = _sentence_span(clean, hit.start())
            return _CITE_RE.search(clean[lo:hi]) is not None

        return all(hit.group(0) in _TRIVIAL_NUMS or _cited(hit)
                   for hit in _NUM_RE.finditer(clean))
    return _fn
def _failing_sentences_for_citations(text: str) -> list[str]:
    """List the sentences of *text* that contain a non-trivial number but no
    [doc_id] citation, so a reroll can be told exactly which spots to fix.

    Bold markers are unwrapped first. Sentences come from splitting on the
    same terminator pattern the citations_dense check uses. The split
    consumes the terminator, so returned sentences lack their closing
    period (and a trailing ")" or whitespace character).
    """
    unbolded = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)

    def _needs_citation(sentence: str) -> bool:
        has_real_number = any(tok not in _TRIVIAL_NUMS for tok in _NUM_RE.findall(sentence))
        return has_real_number and _CITE_RE.search(sentence) is None

    return [sentence for sentence in re.split(r"\.[\s)]|\.$", unbolded)
            if _needs_citation(sentence)]
def _check_referenced_doc_ids_exist(doc_msgs: list[dict]):
    """Build the citations_resolve validator for *doc_msgs*.

    The returned callable passes when every tag matched by ``_CITE_RE``
    names an id in ``_doc_ids(doc_msgs)``. The id set is computed once,
    when the validator is built. Text with no citations passes.
    """
    known_ids = _doc_ids(doc_msgs)

    def _fn(text: str):
        return all(hit.group("id") in known_ids for hit in _CITE_RE.finditer(text))
    return _fn
# --- main entry point ------------------------------------------------------
def reconcile_strict(doc_msgs: list[dict],
                     system_prompt: str,
                     user_prompt: str = "Write the cited briefing now.",
                     model: str | None = None,
                     loop_budget: int = DEFAULT_LOOP_BUDGET,
                     ollama_options: dict | None = None) -> dict[str, Any]:
    """Run Granite reconciliation with Mellea rejection sampling.

    Mellea's instruct() takes one description, not a message list, so the
    document messages are serialized into a single instruction. Mellea then
    samples up to ``loop_budget`` candidates against the four requirements.
    Finally, the returned paragraph is re-checked locally, so the pass/fail
    metadata does not depend on Mellea's version-specific result objects.

    Args:
        doc_msgs: document messages with a ``role`` of ``"document <id>"``
            and a ``content`` string.
        system_prompt: reconciler system prompt, placed first in the
            instruction.
        user_prompt: task line appended after the documents.
        model: Ollama model id; defaults to ``DEFAULT_MODEL``.
        loop_budget: maximum number of samples Mellea may draw.
        ollama_options: extra Ollama options. They are given to the session
            and merged over the per-call defaults (temperature 0, num_ctx,
            num_predict).

    Returns a dict with:
        paragraph — final text returned by Mellea, stripped
        rerolls — number of resamples (0 = passed first try)
        n_attempts — number of samples drawn
        requirements_total — number of requirements checked
        requirements_passed — names of requirements the final paragraph
            passes
        requirements_failed — names it fails (empty when fully compliant)
        elapsed_s — total seconds including rerolls
        model — model id used
        loop_budget — configured budget

    Raises:
        Whatever start_session() or session.instruct() raise. The session is
        cleaned up before an instruct() error propagates.
    """
    model = model or DEFAULT_MODEL
    t0 = time.time()
    # Per-requirement closures wired with the doc context.
    # The validator functions are kept in our own table so they can be
    # re-run on the final paragraph to produce reliable pass/fail metadata
    # for the report. Mellea's internal validation-result objects vary by
    # version and aren't great for downstream display.
    # Each description must state what its check actually enforces. It is
    # the human-readable text attached to the Mellea requirement (and may
    # be surfaced to the model), so it must not contradict the validator.
    checks = [
        ("numerics_grounded",
         "All numbers in the output must appear verbatim in the source documents.",
         _check_no_invented_numbers(doc_msgs)),
        ("no_placeholder_tokens",
         "The output must not contain placeholder tokens like [source] or raw <document> markup.",
         _check_no_placeholder_tokens()),
        ("citations_dense",
         "Every sentence that contains a numeric claim must include a [doc_id] citation.",
         _check_every_claim_cited()),
        ("citations_resolve",
         "Every cited [doc_id] must correspond to a real source document.",
         _check_referenced_doc_ids_exist(doc_msgs)),
    ]
    requirements = [
        req(desc, validation_fn=simple_validate(fn, reason=name))
        for name, desc, fn in checks
    ]

    def _doc_id(msg: dict) -> str:
        # Same "document <id>" parsing as _doc_ids() (including strip()), so
        # the ids shown to the model are exactly the ids citations_resolve
        # accepts.
        role = msg.get("role", "")
        if role.startswith("document "):
            return role.split(" ", 1)[1].strip()
        return "unknown"

    session = start_session(backend_name="ollama", model_id=model,
                            model_options=ollama_options or {})
    try:
        # Build the prompt: system + serialized doc context + user task.
        # Serializing the doc messages into the instruction makes the
        # haystack available to the model the same way it would be via
        # ollama.chat with role="document <id>" messages.
        doc_block = "\n\n".join(
            f"<document id=\"{_doc_id(m)}\">\n"
            f"{m.get('content', '')}\n</document>"
            for m in doc_msgs
        )
        instruction = (
            f"{system_prompt}\n\n"
            f"DOCUMENTS:\n{doc_block}\n\n"
            f"TASK: {user_prompt}"
        )
        result = session.instruct(
            description=instruction,
            strategy=RejectionSamplingStrategy(
                loop_budget=loop_budget,
                requirements=requirements,
            ),
            requirements=requirements,
            return_sampling_results=True,
            model_options={"temperature": 0,
                           "num_ctx": int(os.environ.get("RIPRAP_MELLEA_NUM_CTX", "4096")),
                           "num_predict": int(os.environ.get("RIPRAP_MELLEA_NUM_PREDICT", "400")),
                           **(ollama_options or {})},
        )
        paragraph = _extract_text(result).strip()
        n_attempts = _extract_attempts(result)
        rerolls = max(0, n_attempts - 1)
    finally:
        try:
            session.cleanup()
        except Exception:
            # Best-effort teardown: never let cleanup mask the result or the
            # original exception, but leave a trace for debugging.
            log.debug("mellea session cleanup failed", exc_info=True)
    # Re-run our own checks on the final paragraph for clean pass/fail
    # metadata. This is what shows up in the report's compliance section.
    passed: list[str] = []
    failed: list[str] = []
    for name, _desc, fn in checks:
        try:
            if fn(paragraph):
                passed.append(name)
            else:
                failed.append(name)
        except Exception as e:
            log.warning("requirement %s raised: %r", name, e)
            failed.append(name)
    return {
        "paragraph": paragraph,
        "rerolls": rerolls,
        "n_attempts": n_attempts,
        "requirements_total": len(checks),
        "requirements_passed": passed,
        "requirements_failed": failed,
        "elapsed_s": round(time.time() - t0, 2),
        "model": model,
        "loop_budget": loop_budget,
    }
def _reroll_feedback_message(failed: list[str], previous_paragraph: str) -> str:
    """Build the corrective user turn appended before a streaming reroll.

    Names the failed requirements, then adds concrete guidance for each:
    - a one-line fix-it hint for numerics_grounded, no_placeholder_tokens
      and citations_resolve (the bare internal name gives the model nothing
      to act on);
    - sentence-level feedback for citations_dense, quoting up to three
      uncited sentences when they can be located.
    It always ends by asking for the full briefing, because the next
    attempt's output replaces the whole briefing.
    """
    lines = [f"Your previous draft failed: {', '.join(failed)}."]
    hints = {
        "numerics_grounded": (
            "Use only numbers that appear verbatim in the documents; "
            "remove or rephrase any figure you cannot find there."),
        "no_placeholder_tokens": (
            "Do not output placeholder citations or raw document markup; "
            "every bracketed citation must name a real document id."),
        "citations_resolve": (
            "Cite only the ids of the documents you were given; "
            "remove any other bracketed ids."),
    }
    lines.extend(hints[name] for name in failed if name in hints)
    if "citations_dense" in failed and previous_paragraph:
        bad = _failing_sentences_for_citations(previous_paragraph)
        if bad:
            lines.append("Specific sentences with uncited numbers:")
            lines.extend(f"  - {s.strip()}" for s in bad[:3])
            lines.append("Add a [doc_id] citation at the end of each.")
        else:
            lines.append("Re-write so every sentence containing a number ends "
                         "with a [doc_id] citation.")
    lines.append("Re-emit the FULL briefing.")
    return "\n".join(lines)


def reconcile_strict_streaming(
    doc_msgs: list[dict],
    system_prompt: str,
    user_prompt: str = "Write the cited briefing now.",
    model: str | None = None,
    loop_budget: int = DEFAULT_LOOP_BUDGET,
    ollama_options: dict | None = None,
    on_token=None,
    on_attempt_end=None,
) -> dict[str, Any]:
    """Hand-rolled rejection sampler that *streams* each attempt.

    Unlike Mellea, which validates silently behind the scenes, this streams
    every attempt to the user. It keeps the same compliance contract as
    reconcile_strict: it runs the same four checks, accepts the first
    attempt that passes, and falls back to the last attempt if the budget
    is exhausted. Each reroll gets a user turn naming what failed, with
    targeted guidance (see _reroll_feedback_message).

    Callbacks (both optional, both fire on the calling thread). An
    exception raised by a callback is logged and swallowed.
        on_token(delta: str, attempt_idx: int)
            — fires for every token chunk as it arrives from Granite.
        on_attempt_end(attempt_idx: int, passed: list[str], failed: list[str])
            — fires after each attempt with its per-requirement outcome.
              The frontend uses this to render reroll banners + reset the
              briefing buffer when a new attempt begins.

    Returns the same dict shape as reconcile_strict. If ``loop_budget`` is
    below 1, no attempt runs: the paragraph is "" and every requirement is
    reported as failed.
    """
    model = model or DEFAULT_MODEL
    t0 = time.time()
    checks = [
        ("numerics_grounded",
         _check_no_invented_numbers(doc_msgs)),
        ("no_placeholder_tokens",
         _check_no_placeholder_tokens()),
        ("citations_dense",
         _check_every_claim_cited()),
        ("citations_resolve",
         _check_referenced_doc_ids_exist(doc_msgs)),
    ]
    base_messages = doc_msgs + [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # num_ctx 4096 fits a typical trimmed prompt (≈700 system + ≈2500 docs);
    # num_predict 400 caps the 4-section briefing at ≈300-350 tokens. With
    # RIPRAP_TRIM_DOCS=1 and the planner picking 6-9 specialists, the 4096
    # window has been sufficient on every probe. The previous 6144/600 was
    # sized for the *untrimmed* fan-out and was forcing Ollama to grow the
    # KV cache (33% more memory + a full re-init) every Mellea attempt.
    # Override with RIPRAP_MELLEA_NUM_CTX / RIPRAP_MELLEA_NUM_PREDICT.
    base_opts = {"temperature": 0,
                 "num_ctx": int(os.environ.get("RIPRAP_MELLEA_NUM_CTX", "4096")),
                 "num_predict": int(os.environ.get("RIPRAP_MELLEA_NUM_PREDICT", "400")),
                 **(ollama_options or {})}
    paragraph = ""
    last_passed: list[str] = []
    last_failed: list[str] = [name for name, _ in checks]
    attempts = 0
    for attempt_idx in range(loop_budget):
        attempts = attempt_idx + 1
        messages = list(base_messages)
        if attempt_idx > 0 and last_failed:
            # `paragraph` still holds the previous attempt's text here.
            messages.append({"role": "user",
                             "content": _reroll_feedback_message(last_failed, paragraph)})
        chunks: list[str] = []
        for chunk in llm.chat(model=model, messages=messages,
                              stream=True, options=base_opts):
            delta = (chunk.get("message") or {}).get("content") or ""
            if delta:
                chunks.append(delta)
                if on_token is not None:
                    try:
                        on_token(delta, attempt_idx)
                    except Exception:
                        log.exception("on_token callback raised")
        paragraph = "".join(chunks).strip()
        passed: list[str] = []
        failed: list[str] = []
        for name, fn in checks:
            try:
                (passed if fn(paragraph) else failed).append(name)
            except Exception as e:
                log.warning("requirement %s raised: %r", name, e)
                failed.append(name)
        last_passed = passed
        last_failed = failed
        if on_attempt_end is not None:
            try:
                on_attempt_end(attempt_idx, passed, failed)
            except Exception:
                log.exception("on_attempt_end callback raised")
        if not failed:
            break
    return {
        "paragraph": paragraph,
        "rerolls": max(0, attempts - 1),
        "n_attempts": attempts,
        "requirements_total": len(checks),
        "requirements_passed": last_passed,
        "requirements_failed": last_failed,
        "elapsed_s": round(time.time() - t0, 2),
        "model": model,
        "loop_budget": loop_budget,
    }
def _extract_text(result) -> str:
    """Pull the generated text out of a Mellea result object.

    Probes ``sample``, ``result``, ``value`` and ``content`` in that order
    and takes the first one that is not None. If that object has a
    ``.value`` attribute, the value is stringified; otherwise the object
    itself is. When none of the attributes is set, the result itself is
    stringified.
    """
    probed = (getattr(result, name, None)
              for name in ("sample", "result", "value", "content"))
    candidate = next((obj for obj in probed if obj is not None), None)
    if candidate is None:
        return str(result)
    inner = candidate.value if hasattr(candidate, "value") else candidate
    return str(inner)
def _extract_attempts(result) -> int:
    """Return how many samples were drawn before sampling stopped.

    The first int among ``n_attempts`` / ``num_attempts`` / ``attempts``
    wins. Otherwise the length of ``sample_validations`` (or, if that is
    falsy, ``samples``) is used when it is a list. Falls back to 1.
    """
    counts = (getattr(result, name, None)
              for name in ("n_attempts", "num_attempts", "attempts"))
    explicit = next((c for c in counts if isinstance(c, int)), None)
    if explicit is not None:
        return explicit
    history = getattr(result, "sample_validations", None) or getattr(result, "samples", None)
    return len(history) if isinstance(history, list) else 1
def _extract_pass_fail(result) -> tuple[list[str], list[str]]:
    """Best-effort split of the final sample's requirements into passed/failed.

    Reads ``result.sample_validations`` and takes its last entry when it is
    a list. That entry may be a list of ``(requirement, validation)`` tuples,
    a single validation object, or a requirement, depending on the mellea
    version, so everything is duck-typed. Labels are ``str()`` of the
    requirement, truncated to 80 chars. A validation counts as passed by the
    first of ``passed`` / ``is_valid`` / ``result`` it has (default False).
    Returns two empty lists when there are no validations.
    """
    def _flag(obj) -> bool:
        return bool(getattr(obj, "passed", getattr(obj, "is_valid",
                                                   getattr(obj, "result", False))))

    validations = getattr(result, "sample_validations", None)
    if not validations:
        return [], []
    final = validations[-1] if isinstance(validations, list) else validations
    entries = final if isinstance(final, list) else [final]
    passed: list[str] = []
    failed: list[str] = []
    for entry in entries:
        if isinstance(entry, tuple) and len(entry) >= 2:
            label, ok = str(entry[0])[:80], _flag(entry[1])
        else:
            label, ok = str(getattr(entry, "requirement", entry))[:80], _flag(entry)
        (passed if ok else failed).append(label)
    return passed, failed