| """Mellea-validated reconciliation for Riprap. |
| |
| Wraps the existing Granite-via-Ollama reconciliation in IBM Research's |
| Mellea framework: typed output + programmatic post-conditions + |
| rejection sampling. Replaces post-hoc sentence-dropping with |
| "don't accept output until requirements pass." |
| |
| Streaming and rejection sampling are mutually exclusive — by the time |
| we'd validate, the user has watched the bad output appear. Strict mode |
| trades streaming for compliance; the UI shows a "validating" skeleton |
| instead of token-by-token render. |
| |
| The four invariants ported from the parent project's mellea_probe: |
| |
| 1. no_invented_numbers — every number in output appears in source |
| 2. no_placeholder_tokens — output never contains "[source]" or |
| raw <document> markup |
3. every_claim_cited — each numeric token shares a sentence with
   a [doc_id] citation
| 4. referenced_doc_ids_exist — cited doc_ids ⊆ input doc_ids |
| """ |
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import re |
| import time |
| from typing import Any |
|
|
| from mellea import start_session |
| from mellea.stdlib.requirements import req, simple_validate |
| from mellea.stdlib.sampling import RejectionSamplingStrategy |
|
|
| from app import llm |
|
|
| log = logging.getLogger("riprap.mellea") |
|
|
| |
| DEFAULT_MODEL = os.environ.get( |
| "RIPRAP_RECONCILER_MODEL", |
| os.environ.get("RIPRAP_OLLAMA_MODEL", "granite4.1:8b"), |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
def _default_loop_budget() -> int:
    """Resolve the rejection-sampling attempt budget from the environment.

    A positive integer in RIPRAP_MELLEA_MAX_ATTEMPTS wins outright.
    Otherwise: 2 attempts when the primary backend is Ollama (local
    rerolls are slow), 3 for anything else.
    """
    raw = os.environ.get("RIPRAP_MELLEA_MAX_ATTEMPTS", "0")
    try:
        override = int(raw)
    except ValueError:
        override = 0
    if override > 0:
        return override
    primary = os.environ.get("RIPRAP_LLM_PRIMARY", "ollama").lower()
    return 2 if primary == "ollama" else 3
|
|
|
|
| DEFAULT_LOOP_BUDGET = _default_loop_budget() |
|
|
| |
| |
| |
| _NUM_RE = re.compile(r"\b-?\d[\d,]*(?:\.\d+)?\b") |
| _CITE_RE = re.compile(r"\[(?P<id>[a-z][a-z0-9_]*)\]") |
| |
| |
| _TRIVIAL_NUMS = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", |
| "311", "911", "211"} |
|
|
|
|
def _strip_markdown_for_check(text: str) -> str:
    """Remove **bold** markers and [doc_id] citation tags so the numeric
    scan sees only the prose, never the markup."""
    without_bold = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
    return re.sub(r"\[[a-z0-9_]+\]", "", without_bold, flags=re.I)
|
|
|
|
def _normalize_num(s: str) -> set[str]:
    """Return the equivalent spellings of a numeric token: the raw form,
    the comma-free form, and — for decimals — the form with trailing
    zeros and any dangling point stripped."""
    bare = s.replace(",", "")
    variants = {s, bare}
    if "." in bare:
        variants.add(bare.rstrip("0").rstrip("."))
    return {v for v in variants if v}
|
|
|
|
def _haystack(doc_msgs: list[dict]) -> str:
    """Concatenate every document body into one newline-joined string
    for the grounding scan."""
    bodies = [msg.get("content", "") for msg in doc_msgs]
    return "\n".join(bodies)
|
|
|
|
def _doc_ids(doc_msgs: list[dict]) -> set[str]:
    """Collect the ids of document messages, whose roles look like
    "document <id>"; non-document roles are ignored."""
    return {
        msg.get("role", "").split(" ", 1)[1].strip()
        for msg in doc_msgs
        if msg.get("role", "").startswith("document ")
    }
|
|
|
|
| |
|
|
|
|
def _check_no_invented_numbers(doc_msgs: list[dict]):
    """Build a validator: every non-trivial number in the output must
    match a number extracted from the source documents.

    Numbers are compared as normalized token sets (raw / comma-free /
    trailing-zero-stripped), not raw substrings.  The previous substring
    test had two failure modes: it accepted invented numbers that
    happened to be substrings of a source number (output "42" vs source
    "1420"), and it rejected genuine numbers whose formatting differed
    (output "1500" vs source "1,500"), forcing needless rerolls.
    """
    haystack = _haystack(doc_msgs)
    # Pre-extract every source number once and expand each to all of its
    # accepted spellings; membership checks are then O(1) per token.
    source_forms: set[str] = set()
    for n in _NUM_RE.findall(haystack):
        source_forms |= _normalize_num(n)

    def _fn(text: str):
        clean = _strip_markdown_for_check(text)
        invented = []
        for n in _NUM_RE.findall(clean):
            if n in _TRIVIAL_NUMS:
                continue
            # Invented iff no spelling of this token matches any source
            # number's spellings.
            if source_forms.isdisjoint(_normalize_num(n)):
                invented.append(n)
        return not invented
    return _fn
|
|
|
|
def _check_no_placeholder_tokens():
    """Build a validator: the output must not leak prompt scaffolding.

    Rejects the literal "[source]" tag (any case), raw <document> /
    </document> markup, and an un-substituted "[doc_id]" placeholder.
    """
    def _fn(text: str):
        leaked = "[source]" in text.lower()
        leaked = leaked or "<document" in text
        leaked = leaked or "</document" in text
        leaked = leaked or "[doc_id]" in text
        return not leaked
    return _fn
|
|
|
|
def _check_every_claim_cited():
    """Build a validator: every non-trivial numeric token must share a
    sentence with at least one [doc_id] citation.

    Sentence boundaries are conservative — a period followed by
    whitespace or ")", or the end of the text.  This matches how a
    reader attributes claims: the citation may sit anywhere in the
    sentence, not necessarily adjacent to the number.
    """
    boundary = re.compile(r"\.[\s)]|\.$")

    def _enclosing_sentence(text: str, pos: int) -> tuple[int, int]:
        # Left edge: just past the last boundary that ends before pos.
        left = 0
        for hit in boundary.finditer(text, 0, pos):
            left = hit.end()
        # Right edge: one past the next terminating period, else EOT.
        nxt = boundary.search(text, pos)
        right = nxt.start() + 1 if nxt else len(text)
        return left, right

    def _fn(text: str):
        clean = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
        for hit in _NUM_RE.finditer(clean):
            token = hit.group(0)
            if token in _TRIVIAL_NUMS:
                continue
            lo, hi = _enclosing_sentence(clean, hit.start())
            if not _CITE_RE.search(clean[lo:hi]):
                return False
        return True
    return _fn
|
|
|
|
def _failing_sentences_for_citations(text: str) -> list[str]:
    """List the sentences of `text` that mention a non-trivial number
    but carry no [doc_id] citation.

    Feeds the reroll prompt so the model gets pointed at the exact
    spots that failed validation instead of a generic complaint.
    """
    clean = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
    uncited = []
    for sentence in re.split(r"\.[\s)]|\.$", clean):
        has_claim = any(
            tok not in _TRIVIAL_NUMS for tok in _NUM_RE.findall(sentence)
        )
        if has_claim and not _CITE_RE.search(sentence):
            uncited.append(sentence)
    return uncited
|
|
|
|
def _check_referenced_doc_ids_exist(doc_msgs: list[dict]):
    """Build a validator: the set of [doc_id] tags cited in the output
    must be a subset of the ids of the supplied documents."""
    known = _doc_ids(doc_msgs)

    def _fn(text: str):
        cited = {m.group("id") for m in _CITE_RE.finditer(text)}
        return cited <= known
    return _fn
|
|
|
|
| |
|
|
|
|
def reconcile_strict(doc_msgs: list[dict],
                     system_prompt: str,
                     user_prompt: str = "Write the cited briefing now.",
                     model: str | None = None,
                     loop_budget: int = DEFAULT_LOOP_BUDGET,
                     ollama_options: dict | None = None) -> dict[str, Any]:
    """Run Granite reconciliation with Mellea rejection sampling.

    Non-streaming: blocks until Mellea accepts a sample or exhausts the
    loop budget (see reconcile_strict_streaming for the streaming path).

    Returns a dict with:
        paragraph           — final validated text
        rerolls             — number of resamples (0 = passed first try)
        n_attempts          — total samples drawn
        requirements_total  — number of configured checks
        requirements_passed — list of requirement names that passed on
                              the returned sample (re-checked locally)
        requirements_failed — list of requirement names that failed
                              (empty on an accepted sample)
        elapsed_s           — total seconds including rerolls
        model               — model id used
        loop_budget         — configured budget
    """
    model = model or DEFAULT_MODEL
    t0 = time.time()

    # (name, natural-language description shown to the model, callable check).
    # NOTE(review): the citations_dense description says "~120 characters"
    # but the actual check (_check_every_claim_cited) is sentence-scoped,
    # and the module docstring says ~40 chars — confirm intended wording.
    checks = [
        ("numerics_grounded",
         "All numbers in the output must appear verbatim in the source documents.",
         _check_no_invented_numbers(doc_msgs)),
        ("no_placeholder_tokens",
         "The output must not contain placeholder tokens like [source] or raw <document> markup.",
         _check_no_placeholder_tokens()),
        ("citations_dense",
         "Every numeric claim must have a [doc_id] citation within ~120 characters.",
         _check_every_claim_cited()),
        ("citations_resolve",
         "Every cited [doc_id] must correspond to a real source document.",
         _check_referenced_doc_ids_exist(doc_msgs)),
    ]
    # Wrap each check as a Mellea requirement: description for the model,
    # programmatic validator for the sampler.
    # NOTE(review): assumes the pinned mellea's simple_validate accepts a
    # `reason=` kwarg — verify against the installed version.
    requirements = [
        req(desc, validation_fn=simple_validate(fn, reason=name))
        for name, desc, fn in checks
    ]

    session = start_session(backend_name="ollama", model_id=model,
                            model_options=ollama_options or {})
    try:
        # Flatten the role-tagged doc messages into explicit <document>
        # markup inside one instruction string; roles not shaped like
        # "document <id>" get id "unknown".
        doc_block = "\n\n".join(
            f"<document id=\"{m['role'].split(' ', 1)[1] if m['role'].startswith('document ') else 'unknown'}\">\n"
            f"{m['content']}\n</document>"
            for m in doc_msgs
        )
        instruction = (
            f"{system_prompt}\n\n"
            f"DOCUMENTS:\n{doc_block}\n\n"
            f"TASK: {user_prompt}"
        )

        # Rejection sampling: Mellea resamples until every requirement
        # passes or loop_budget attempts are spent.
        result = session.instruct(
            description=instruction,
            strategy=RejectionSamplingStrategy(
                loop_budget=loop_budget,
                requirements=requirements,
            ),
            requirements=requirements,
            return_sampling_results=True,
            # temperature pinned to 0; context/output sizes env-tunable.
            model_options={"temperature": 0,
                           "num_ctx": int(os.environ.get("RIPRAP_MELLEA_NUM_CTX", "4096")),
                           "num_predict": int(os.environ.get("RIPRAP_MELLEA_NUM_PREDICT", "400")),
                           **(ollama_options or {})},
        )

        paragraph = _extract_text(result).strip()
        n_attempts = _extract_attempts(result)
        rerolls = max(0, n_attempts - 1)
    finally:
        # Best-effort teardown; never mask the real result or exception.
        try:
            session.cleanup()
        except Exception:
            pass

    # Re-run the checks locally on the final text so the report does not
    # depend on how this mellea version exposes validation results.
    passed: list[str] = []
    failed: list[str] = []
    for name, _desc, fn in checks:
        try:
            if fn(paragraph):
                passed.append(name)
            else:
                failed.append(name)
        except Exception as e:
            log.warning("requirement %s raised: %r", name, e)
            failed.append(name)

    return {
        "paragraph": paragraph,
        "rerolls": rerolls,
        "n_attempts": n_attempts,
        "requirements_total": len(checks),
        "requirements_passed": passed,
        "requirements_failed": failed,
        "elapsed_s": round(time.time() - t0, 2),
        "model": model,
        "loop_budget": loop_budget,
    }
|
|
|
|
def reconcile_strict_streaming(
    doc_msgs: list[dict],
    system_prompt: str,
    user_prompt: str = "Write the cited briefing now.",
    model: str | None = None,
    loop_budget: int = DEFAULT_LOOP_BUDGET,
    ollama_options: dict | None = None,
    on_token=None,
    on_attempt_end=None,
) -> dict[str, Any]:
    """Hand-rolled rejection sampler that *streams* each attempt to the
    user instead of waiting silently for Mellea to validate behind the
    scenes.  Same compliance contract as reconcile_strict — runs the
    same four checks, accepts the first attempt that passes, falls back
    to the last attempt if the budget is exhausted.  A non-positive
    loop_budget is clamped so at least one attempt is always made.

    Callbacks (both optional, both fire on the calling thread):
        on_token(delta: str, attempt_idx: int)
            — fires for every token chunk as it arrives from Granite.
        on_attempt_end(attempt_idx: int, passed: list[str], failed: list[str])
            — fires after each attempt with its per-requirement outcome.
              The frontend uses this to render reroll banners + reset
              the briefing buffer when a new attempt begins.

    Returns the same result dict shape as reconcile_strict.
    """
    model = model or DEFAULT_MODEL
    t0 = time.time()

    # Same four invariants as reconcile_strict, keyed by name.
    checks = [
        ("numerics_grounded",
         _check_no_invented_numbers(doc_msgs)),
        ("no_placeholder_tokens",
         _check_no_placeholder_tokens()),
        ("citations_dense",
         _check_every_claim_cited()),
        ("citations_resolve",
         _check_referenced_doc_ids_exist(doc_msgs)),
    ]

    # Documents first, then system + task, matching the non-streaming path.
    base_messages = doc_msgs + [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # temperature pinned to 0; context/output sizes env-tunable; caller
    # options override the defaults.
    base_opts = {"temperature": 0,
                 "num_ctx": int(os.environ.get("RIPRAP_MELLEA_NUM_CTX", "4096")),
                 "num_predict": int(os.environ.get("RIPRAP_MELLEA_NUM_PREDICT", "400")),
                 **(ollama_options or {})}

    paragraph = ""
    last_passed: list[str] = []
    last_failed: list[str] = [name for name, _ in checks]
    last_paragraph = ""
    attempts = 0

    # Fix: guard against loop_budget <= 0, which previously skipped the
    # loop entirely and returned an empty paragraph with attempts == 0.
    for attempt_idx in range(max(1, loop_budget)):
        attempts = attempt_idx + 1

        # On rerolls, append targeted feedback about the previous failure
        # so the model repairs the exact spots instead of guessing.
        messages = list(base_messages)
        if attempt_idx > 0 and last_failed:
            feedback = [
                f"Your previous draft failed: {', '.join(last_failed)}.",
            ]
            if "citations_dense" in last_failed and last_paragraph:
                bad = _failing_sentences_for_citations(last_paragraph)
                if bad:
                    feedback.append(
                        "Specific sentences with uncited numbers:"
                    )
                    for s in bad[:3]:
                        feedback.append(f" - {s.strip()}")
                    feedback.append(
                        "Add a [doc_id] citation at the end of each. "
                        "Re-emit the FULL briefing."
                    )
                else:
                    feedback.append(
                        "Re-write so every sentence containing a number ends "
                        "with a [doc_id] citation."
                    )
            messages.append({"role": "user", "content": "\n".join(feedback)})

        # Stream this attempt; forward every chunk to the UI callback.
        chunks: list[str] = []
        for chunk in llm.chat(model=model, messages=messages,
                              stream=True, options=base_opts):
            delta = (chunk.get("message") or {}).get("content") or ""
            if delta:
                chunks.append(delta)
                if on_token is not None:
                    try:
                        on_token(delta, attempt_idx)
                    except Exception:
                        # Callback bugs must not abort the generation.
                        log.exception("on_token callback raised")
        paragraph = "".join(chunks).strip()

        # Validate the completed attempt against all four checks.
        passed: list[str] = []
        failed: list[str] = []
        for name, fn in checks:
            try:
                (passed if fn(paragraph) else failed).append(name)
            except Exception as e:
                log.warning("requirement %s raised: %r", name, e)
                failed.append(name)

        last_passed = passed
        last_failed = failed
        last_paragraph = paragraph
        if on_attempt_end is not None:
            try:
                on_attempt_end(attempt_idx, passed, failed)
            except Exception:
                log.exception("on_attempt_end callback raised")

        # Accept the first fully-compliant attempt.
        if not failed:
            break

    return {
        "paragraph": paragraph,
        "rerolls": max(0, attempts - 1),
        "n_attempts": attempts,
        "requirements_total": len(checks),
        "requirements_passed": last_passed,
        "requirements_failed": last_failed,
        "elapsed_s": round(time.time() - t0, 2),
        "model": model,
        "loop_budget": loop_budget,
    }
|
|
|
|
def _extract_text(result) -> str:
    """Duck-typed text extraction from a Mellea SamplingResult or
    ModelOutputThunk, tolerant of attribute differences across
    library versions."""
    for attr_name in ("sample", "result", "value", "content"):
        payload = getattr(result, attr_name, None)
        if payload is None:
            continue
        if hasattr(payload, "value"):
            return str(payload.value)
        return str(payload)
    return str(result)
|
|
|
|
def _extract_attempts(result) -> int:
    """Best-effort count of how many samples were drawn before the
    sampler stopped; 1 when nothing usable is exposed."""
    for attr_name in ("n_attempts", "num_attempts", "attempts"):
        count = getattr(result, attr_name, None)
        if isinstance(count, int):
            return count
    drawn = (getattr(result, "sample_validations", None)
             or getattr(result, "samples", None))
    if isinstance(drawn, list):
        return len(drawn)
    return 1
|
|
|
|
def _extract_pass_fail(result) -> tuple[list[str], list[str]]:
    """Best-effort split of requirement descriptions into (passed,
    failed) for the accepted sample.  mellea v0.4 exposes
    sample_validations as a list where each entry is itself a list of
    (Requirement, ValidationResult) tuples — duck-type defensively.
    """
    validations = getattr(result, "sample_validations", None)
    if not validations:
        return [], []
    final = validations[-1] if isinstance(validations, list) else validations
    entries = final if isinstance(final, list) else [final]
    passed: list[str] = []
    failed: list[str] = []
    for entry in entries:
        if isinstance(entry, tuple) and len(entry) >= 2:
            label = str(entry[0])[:80]
            verdict = entry[1]
        else:
            label = str(getattr(entry, "requirement", entry))[:80]
            verdict = entry
        ok = bool(getattr(verdict, "passed",
                          getattr(verdict, "is_valid",
                                  getattr(verdict, "result", False))))
        (passed if ok else failed).append(label)
    return passed, failed
|
|