"""Mellea-validated reconciliation for Riprap. Wraps the existing Granite-via-Ollama reconciliation in IBM Research's Mellea framework: typed output + programmatic post-conditions + rejection sampling. Replaces post-hoc sentence-dropping with "don't accept output until requirements pass." Streaming and rejection sampling are mutually exclusive — by the time we'd validate, the user has watched the bad output appear. Strict mode trades streaming for compliance; the UI shows a "validating" skeleton instead of token-by-token render. The four invariants ported from the parent project's mellea_probe: 1. no_invented_numbers — every number in output appears in source 2. no_placeholder_tokens — output never contains "[source]" or raw markup 3. every_claim_cited — each numeric token has a [doc_id] within ~40 chars 4. referenced_doc_ids_exist — cited doc_ids ⊆ input doc_ids """ from __future__ import annotations import logging import os import re import time from typing import Any from mellea import start_session from mellea.stdlib.requirements import req, simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy from app import llm log = logging.getLogger("riprap.mellea") # Default reconciler model — same env-var contract as app/reconcile.py. DEFAULT_MODEL = os.environ.get( "RIPRAP_RECONCILER_MODEL", os.environ.get("RIPRAP_OLLAMA_MODEL", "granite4.1:8b"), ) # Loop budget — try up to N samples before falling back to the last # candidate even if it didn't pass all requirements. Low ceiling so a # pathological case can't run away with latency. # # Override at process start with RIPRAP_MELLEA_MAX_ATTEMPTS. We default # to 2 on the local Ollama path (where each attempt is 30-90 s on the # Mac) and 3 on remote/vLLM (where attempts are seconds). This caps # worst-case demo latency without giving up the principal grounding # guarantee — the first-attempt pass rate on the curated probes is >85%. def _default_loop_budget() -> int: try: n = int(os.environ.get("RIPRAP_MELLEA_MAX_ATTEMPTS", "0")) if n > 0: return n except ValueError: pass return 2 if os.environ.get("RIPRAP_LLM_PRIMARY", "ollama").lower() == "ollama" else 3 DEFAULT_LOOP_BUDGET = _default_loop_budget() # Number tokens — \b enforces a word boundary so identifier codes like # QN1206, B12 (community board), or M14 (bus route) are skipped entirely. # Inside QN1206 there's no \b between any chars, so no submatch leaks. _NUM_RE = re.compile(r"\b-?\d[\d,]*(?:\.\d+)?\b") _CITE_RE = re.compile(r"\[(?P[a-z][a-z0-9_]*)\]") # Same trivial-numbers list as the post-hoc verifier — well-known service # line numbers, single digits. _TRIVIAL_NUMS = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", "311", "911", "211"} def _strip_markdown_for_check(text: str) -> str: """Drop bold markers + citation tags so the numeric scan is clean.""" text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text) text = re.sub(r"\[[a-z0-9_]+\]", "", text, flags=re.I) return text def _normalize_num(s: str) -> set[str]: forms = {s} no_comma = s.replace(",", "") forms.add(no_comma) if "." in no_comma: forms.add(no_comma.rstrip("0").rstrip(".")) return {f for f in forms if f} def _haystack(doc_msgs: list[dict]) -> str: return "\n".join(m.get("content", "") for m in doc_msgs) def _doc_ids(doc_msgs: list[dict]) -> set[str]: """Each doc message has role like "document "; extract ids.""" out = set() for m in doc_msgs: role = m.get("role", "") if role.startswith("document "): out.add(role.split(" ", 1)[1].strip()) return out # --- the four invariants --------------------------------------------------- def _check_no_invented_numbers(doc_msgs: list[dict]): haystack = _haystack(doc_msgs) def _fn(text: str): clean = _strip_markdown_for_check(text) invented = [] for n in _NUM_RE.findall(clean): if n in _TRIVIAL_NUMS: continue forms = _normalize_num(n) if not any(f in haystack for f in forms): invented.append(n) return not invented # pass = no invented numbers return _fn def _check_no_placeholder_tokens(): def _fn(text: str): bad = [] if "[source]" in text.lower(): bad.append("[source]") if "") if "") if "[doc_id]" in text: # Model echoed the EXTRA_SYSTEM_PROMPT skeleton literally bad.append("[doc_id]") return not bad return _fn def _check_every_claim_cited(): """Each non-trivial numeric token must have a [doc_id] somewhere in the same sentence. Sentence boundaries are conservative: a period followed by whitespace, or end of text. This matches how a reader actually attributes claims — the citation can be anywhere in the sentence, not just adjacent to the number.""" # Sentence end = `. ` or `.\n` or end-of-string. Question/exclamation # marks rarely appear in these briefings; period is enough. _SENT_END = re.compile(r"\.[\s)]|\.$") def _sentence_span(text: str, pos: int) -> tuple[int, int]: # Walk backwards to the previous sentence terminator. start = 0 for m in _SENT_END.finditer(text, 0, pos): start = m.end() # Walk forwards to the next. m = _SENT_END.search(text, pos) end = m.start() + 1 if m else len(text) return start, end def _fn(text: str): clean = re.sub(r"\*\*([^*]+)\*\*", r"\1", text) for m in _NUM_RE.finditer(clean): n = m.group(0) if n in _TRIVIAL_NUMS: continue s, e = _sentence_span(clean, m.start()) if not _CITE_RE.search(clean[s:e]): return False return True return _fn def _failing_sentences_for_citations(text: str) -> list[str]: """Return the sentences in `text` that contain a non-trivial number but no [doc_id] citation. Used to give the model targeted reroll feedback so it can fix the exact spots that failed.""" clean = re.sub(r"\*\*([^*]+)\*\*", r"\1", text) sents = re.split(r"\.[\s)]|\.$", clean) bad = [] for s in sents: nums = [n for n in _NUM_RE.findall(s) if n not in _TRIVIAL_NUMS] if nums and not _CITE_RE.search(s): bad.append(s) return bad def _check_referenced_doc_ids_exist(doc_msgs: list[dict]): valid = _doc_ids(doc_msgs) def _fn(text: str): cited = {m.group("id") for m in _CITE_RE.finditer(text)} rogue = cited - valid return not rogue return _fn # --- main entry point ------------------------------------------------------ def reconcile_strict(doc_msgs: list[dict], system_prompt: str, user_prompt: str = "Write the cited briefing now.", model: str | None = None, loop_budget: int = DEFAULT_LOOP_BUDGET, ollama_options: dict | None = None) -> dict[str, Any]: """Run Granite reconciliation with Mellea rejection sampling. Returns a dict with: paragraph — final validated text rerolls — number of resamples (0 = passed first try) requirements_passed — list of requirement names that passed in the accepted sample requirements_failed — list of requirement names that failed (empty on accepted sample) elapsed_s — total seconds including rerolls model — model id used loop_budget — configured budget """ model = model or DEFAULT_MODEL t0 = time.time() # Per-requirement closures wired with the doc context. # Keep the validator functions in our own table so we can re-run them # on the final paragraph to produce reliable pass/fail metadata for # the report — Mellea's internal validation-result objects vary by # version and aren't great for downstream display. checks = [ ("numerics_grounded", "All numbers in the output must appear verbatim in the source documents.", _check_no_invented_numbers(doc_msgs)), ("no_placeholder_tokens", "The output must not contain placeholder tokens like [source] or raw markup.", _check_no_placeholder_tokens()), ("citations_dense", "Every numeric claim must have a [doc_id] citation within ~120 characters.", _check_every_claim_cited()), ("citations_resolve", "Every cited [doc_id] must correspond to a real source document.", _check_referenced_doc_ids_exist(doc_msgs)), ] requirements = [ req(desc, validation_fn=simple_validate(fn, reason=name)) for name, desc, fn in checks ] session = start_session(backend_name="ollama", model_id=model, model_options=ollama_options or {}) try: # Build the prompt: system + serialized doc context + user task. # Mellea's instruct() takes the whole instruction; we serialize # the doc messages into the description so the haystack is # available to the model the same way it would be via # ollama.chat with role="document " messages. doc_block = "\n\n".join( f"\n" f"{m['content']}\n" for m in doc_msgs ) instruction = ( f"{system_prompt}\n\n" f"DOCUMENTS:\n{doc_block}\n\n" f"TASK: {user_prompt}" ) result = session.instruct( description=instruction, strategy=RejectionSamplingStrategy( loop_budget=loop_budget, requirements=requirements, ), requirements=requirements, return_sampling_results=True, model_options={"temperature": 0, "num_ctx": int(os.environ.get("RIPRAP_MELLEA_NUM_CTX", "4096")), "num_predict": int(os.environ.get("RIPRAP_MELLEA_NUM_PREDICT", "400")), **(ollama_options or {})}, ) paragraph = _extract_text(result).strip() n_attempts = _extract_attempts(result) rerolls = max(0, n_attempts - 1) finally: try: session.cleanup() except Exception: pass # Re-run our own checks on the final paragraph for clean pass/fail # metadata. This is what shows up in the report's compliance section. passed: list[str] = [] failed: list[str] = [] for name, _desc, fn in checks: try: if fn(paragraph): passed.append(name) else: failed.append(name) except Exception as e: log.warning("requirement %s raised: %r", name, e) failed.append(name) return { "paragraph": paragraph, "rerolls": rerolls, "n_attempts": n_attempts, "requirements_total": len(checks), "requirements_passed": passed, "requirements_failed": failed, "elapsed_s": round(time.time() - t0, 2), "model": model, "loop_budget": loop_budget, } def reconcile_strict_streaming( doc_msgs: list[dict], system_prompt: str, user_prompt: str = "Write the cited briefing now.", model: str | None = None, loop_budget: int = DEFAULT_LOOP_BUDGET, ollama_options: dict | None = None, on_token=None, on_attempt_end=None, ) -> dict[str, Any]: """Hand-rolled rejection sampler that *streams* each attempt to the user instead of waiting silently for Mellea to validate behind the scenes. Same compliance contract as reconcile_strict — runs the same four checks, accepts the first attempt that passes, falls back to the last attempt if the budget is exhausted. Callbacks (both optional, both fire on the calling thread): on_token(delta: str, attempt_idx: int) — fires for every token chunk as it arrives from Granite. on_attempt_end(attempt_idx: int, passed: list[str], failed: list[str]) — fires after each attempt with its per-requirement outcome. The frontend uses this to render reroll banners + reset the briefing buffer when a new attempt begins. """ model = model or DEFAULT_MODEL t0 = time.time() checks = [ ("numerics_grounded", _check_no_invented_numbers(doc_msgs)), ("no_placeholder_tokens", _check_no_placeholder_tokens()), ("citations_dense", _check_every_claim_cited()), ("citations_resolve", _check_referenced_doc_ids_exist(doc_msgs)), ] base_messages = doc_msgs + [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] # num_ctx 4096 fits a typical trimmed prompt (≈700 system + ≈2500 docs); # num_predict 400 caps the 4-section briefing at ≈300-350 tokens. With # RIPRAP_TRIM_DOCS=1 and the planner picking 6-9 specialists, the 4096 # window has been sufficient on every probe; the previous 6144/600 was # sized for the *untrimmed* fan-out and was forcing Ollama to grow the # KV cache (33% more memory + a full re-init) every Mellea attempt. # Override with RIPRAP_MELLEA_NUM_CTX / RIPRAP_MELLEA_NUM_PREDICT. base_opts = {"temperature": 0, "num_ctx": int(os.environ.get("RIPRAP_MELLEA_NUM_CTX", "4096")), "num_predict": int(os.environ.get("RIPRAP_MELLEA_NUM_PREDICT", "400")), **(ollama_options or {})} paragraph = "" last_passed: list[str] = [] last_failed: list[str] = [name for name, _ in checks] last_paragraph = "" attempts = 0 for attempt_idx in range(loop_budget): attempts = attempt_idx + 1 # On reroll, append a tight feedback message naming what failed AND # the specific failing sentences (so the model knows exactly which # ones to fix). Granite responds well to surgical corrections. messages = list(base_messages) if attempt_idx > 0 and last_failed: feedback = [ f"Your previous draft failed: {', '.join(last_failed)}.", ] if "citations_dense" in last_failed and last_paragraph: bad = _failing_sentences_for_citations(last_paragraph) if bad: feedback.append( "Specific sentences with uncited numbers:" ) for s in bad[:3]: feedback.append(f" - {s.strip()}") feedback.append( "Add a [doc_id] citation at the end of each. " "Re-emit the FULL briefing." ) else: feedback.append( "Re-write so every sentence containing a number ends " "with a [doc_id] citation." ) messages.append({"role": "user", "content": "\n".join(feedback)}) chunks: list[str] = [] for chunk in llm.chat(model=model, messages=messages, stream=True, options=base_opts): delta = (chunk.get("message") or {}).get("content") or "" if delta: chunks.append(delta) if on_token is not None: try: on_token(delta, attempt_idx) except Exception: log.exception("on_token callback raised") paragraph = "".join(chunks).strip() passed: list[str] = [] failed: list[str] = [] for name, fn in checks: try: (passed if fn(paragraph) else failed).append(name) except Exception as e: log.warning("requirement %s raised: %r", name, e) failed.append(name) last_passed = passed last_failed = failed last_paragraph = paragraph if on_attempt_end is not None: try: on_attempt_end(attempt_idx, passed, failed) except Exception: log.exception("on_attempt_end callback raised") if not failed: break return { "paragraph": paragraph, "rerolls": max(0, attempts - 1), "n_attempts": attempts, "requirements_total": len(checks), "requirements_passed": last_passed, "requirements_failed": last_failed, "elapsed_s": round(time.time() - t0, 2), "model": model, "loop_budget": loop_budget, } def _extract_text(result) -> str: """SamplingResult / ModelOutputThunk text extraction.""" for attr in ("sample", "result", "value", "content"): v = getattr(result, attr, None) if v is not None: if hasattr(v, "value"): return str(v.value) return str(v) return str(result) def _extract_attempts(result) -> int: """How many samples were drawn before stopping.""" for attr in ("n_attempts", "num_attempts", "attempts"): v = getattr(result, attr, None) if isinstance(v, int): return v samples = getattr(result, "sample_validations", None) or getattr(result, "samples", None) if isinstance(samples, list): return len(samples) return 1 def _extract_pass_fail(result) -> tuple[list[str], list[str]]: """Best-effort extraction of which requirements passed on the accepted sample. mellea v0.4 exposes sample_validations as a list where each entry is itself a list of (Requirement, ValidationResult) tuples — duck-type defensively. """ validations = getattr(result, "sample_validations", None) if not validations: return [], [] last = validations[-1] if isinstance(validations, list) else validations passed: list[str] = [] failed: list[str] = [] items = last if isinstance(last, list) else [last] for item in items: # Item might be (Requirement, ValidationResult) tuple, or a single # ValidationResult, or a Requirement, depending on mellea version. ok = None descr = "" if isinstance(item, tuple) and len(item) >= 2: descr = str(item[0])[:80] v = item[1] ok = bool(getattr(v, "passed", getattr(v, "is_valid", getattr(v, "result", False)))) else: descr = str(getattr(item, "requirement", item))[:80] ok = bool(getattr(item, "passed", getattr(item, "is_valid", getattr(item, "result", False)))) if ok: passed.append(descr) else: failed.append(descr) return passed, failed