File size: 4,208 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Parser for structured policy decisions.

Provides strict XML parsing, soft parsing fallback, and repair-backed parsing.
"""

from __future__ import annotations

import json
import xml.etree.ElementTree as ET

from app.common.exceptions import ParserError
from app.models.policy.output_schema import DecisionSchema
from app.models.policy.repair import repair_partial_decision

# Child tags that must all be present under <decision>; the strict XML parser
# reports any that are absent as "Missing required XML fields".
_REQUIRED_XML_FIELDS = {
    "action_type",
    "candidate_id",
    "confidence",
    "dose_bucket",
    "mode",
    "replacement_drug",
    "taper_days",
    "target_drug",
}

# Complete tag whitelist: the required tags plus two that may be omitted.
# Any other child tag is rejected by the strict XML parser.
_ALLOWED_XML_FIELDS = {*_REQUIRED_XML_FIELDS, "monitoring_plan", "rationale_brief"}


def _normalize_scalar(value: object) -> object:
    """Collapse null-like values to ``None`` and pass everything else through.

    A string counts as null-like when, after stripping surrounding whitespace
    and lower-casing, it is empty or one of ``"none"``, ``"null"``, ``"na"``.
    Non-string values (including ``None`` itself) are returned unchanged, and
    a string that is not null-like is returned exactly as given (unstripped).
    """
    if not isinstance(value, str):
        # Covers None as well as numbers, lists, etc.
        return value
    token = value.strip().lower()
    return None if token in ("", "none", "null", "na") else value


def _coerce_payload_types(payload: dict[str, object]) -> dict[str, object]:
    """Return a shallow copy of *payload* with nullable and numeric fields coerced.

    The keys ``target_drug``, ``replacement_drug``, ``monitoring_plan``,
    ``taper_days`` and ``rationale_brief`` are always present in the result:
    missing keys and null-like values (as decided by ``_normalize_scalar``)
    become ``None``. A non-``None`` ``taper_days`` is converted with ``int``;
    ``confidence`` is converted with ``float`` whenever the key is present.
    The input mapping itself is not modified.

    Raises:
        ParserError: If ``taper_days`` or ``confidence`` cannot be converted
            to a number (non-numeric text, an empty ``confidence`` value, a
            non-finite ``taper_days``, ...).
    """
    coerced = dict(payload)
    for key in ("target_drug", "replacement_drug", "monitoring_plan", "taper_days", "rationale_brief"):
        coerced[key] = _normalize_scalar(coerced.get(key))

    # Report malformed numbers as ParserError, like every other malformed-input
    # case in this module, instead of leaking a bare ValueError/TypeError
    # (OverflowError covers int(float("inf")) from a JSON ``Infinity``).
    taper_days = coerced["taper_days"]
    if taper_days is not None:
        try:
            coerced["taper_days"] = int(taper_days)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ParserError(f"Invalid taper_days value: {taper_days!r}") from exc

    if "confidence" in coerced:
        confidence = coerced["confidence"]
        try:
            coerced["confidence"] = float(confidence)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ParserError(f"Invalid confidence value: {confidence!r}") from exc

    return coerced


def parse_decision_strict_xml(raw: str) -> DecisionSchema:
    """Parse a ``<decision>`` XML document into a ``DecisionSchema``.

    The text is stripped and parsed with ElementTree. The root element must be
    ``<decision>``, every direct child tag must be in ``_ALLOWED_XML_FIELDS``,
    and every tag in ``_REQUIRED_XML_FIELDS`` must appear. Each child's
    ``.text`` is collected under its tag name; when a tag occurs more than
    once the last occurrence wins. The collected mapping is passed through
    ``_coerce_payload_types`` and then validated by the schema.

    Raises:
        ParserError: If the text cannot be parsed as XML, the root is not
            ``<decision>``, an unknown child tag is found, or required tags
            are missing.

    Exceptions raised by ``_coerce_payload_types`` or by
    ``DecisionSchema.model_validate`` propagate unchanged.
    """
    try:
        document = ET.fromstring(raw.strip())
    except Exception as exc:  # noqa: BLE001
        raise ParserError(f"Invalid XML: {exc}") from exc

    if document.tag != "decision":
        raise ParserError("XML decision root must be <decision>.")

    fields: dict[str, object] = {}
    for element in document:
        tag = element.tag
        if tag in _ALLOWED_XML_FIELDS:
            fields[tag] = element.text
            continue
        # First unknown tag aborts the parse.
        raise ParserError(f"Unknown XML field: {tag}")

    absent = sorted(name for name in _REQUIRED_XML_FIELDS if name not in fields)
    if absent:
        raise ParserError(f"Missing required XML fields: {absent}")

    coerced = _coerce_payload_types(fields)
    return DecisionSchema.model_validate(coerced)


def parse_decision_soft(raw: str) -> DecisionSchema:
    """Best-effort parsing across XML or JSON decision payloads.

    If the stripped text contains ``<decision`` it is treated as XML: the
    whole text is tried with ``parse_decision_strict_xml`` first, and on a
    ``ParserError`` the span from the first ``<decision`` to the last
    ``</decision>`` is extracted and parsed strictly instead (this drops any
    wrapper text before or after the element). Otherwise the text is parsed
    as JSON, which must be a JSON object, then coerced with
    ``_coerce_payload_types`` and validated by the schema.

    Raises:
        ParserError: If the payload is empty, the XML cannot be parsed even
            after extraction, the JSON is invalid, or the JSON value is not
            an object.
    """
    stripped = raw.strip()
    if not stripped:
        raise ParserError("Empty decision payload.")

    if "<decision" in stripped:
        # Try exact parse first.
        try:
            return parse_decision_strict_xml(stripped)
        except ParserError:
            # Attempt to recover a wrapped/trailing buffer.
            start = stripped.find("<decision")
            end = stripped.rfind("</decision>")
            if start >= 0 and end > start:
                end += len("</decision>")
                return parse_decision_strict_xml(stripped[start:end])
            # No closing tag to cut at: surface the original strict error.
            raise

    try:
        payload = json.loads(stripped)
    except Exception as exc:  # noqa: BLE001
        raise ParserError(f"Unable to parse JSON decision: {exc}") from exc

    # json.loads also yields lists, strings, numbers, booleans and None; only
    # an object can be a decision. Without this guard dict(payload) inside
    # _coerce_payload_types fails with an unrelated TypeError/ValueError (or
    # silently accepts a list of pairs).
    if not isinstance(payload, dict):
        raise ParserError(
            f"JSON decision must be an object, got {type(payload).__name__}."
        )
    return DecisionSchema.model_validate(_coerce_payload_types(payload))


def parse_decision_with_repair(raw: str) -> DecisionSchema:
    """Parse *raw* leniently, routing the result through the repair step.

    First attempt: ``parse_decision_soft`` followed by
    ``repair_partial_decision`` on the JSON-mode dump of the parsed model,
    then re-validation. If anything in that attempt raises, a fallback runs:
    text starting with ``{`` is loaded as JSON (an empty dict is used when
    loading fails), any other text yields an empty dict, and that dict is
    repaired and validated.

    Raises:
        ParserError: If the repaired fallback payload fails schema validation.

    Exceptions raised by ``repair_partial_decision`` in the fallback path
    propagate unchanged.
    """
    try:
        decision = parse_decision_soft(raw)
        fixed = repair_partial_decision(decision.model_dump(mode="json"))
        return DecisionSchema.model_validate(fixed)
    except Exception:
        text = raw.strip()
        looks_like_json = text.startswith("{")
        try:
            salvaged: dict[str, object] = json.loads(text) if looks_like_json else {}
        except Exception:  # noqa: BLE001
            # Unreadable JSON: let the repair step start from nothing.
            salvaged = {}
        fixed = repair_partial_decision(salvaged)
        try:
            return DecisionSchema.model_validate(fixed)
        except Exception as exc:  # noqa: BLE001
            raise ParserError(f"Unable to parse decision after repair: {exc}") from exc


def parse_decision(raw: str) -> DecisionSchema:
    """Primary parse entrypoint.

    Text whose first non-whitespace characters are ``<decision`` is handed to
    ``parse_decision_strict_xml`` only (no soft recovery is attempted for it);
    all other text goes to ``parse_decision_soft``. Whatever the chosen parser
    raises propagates to the caller.
    """
    starts_with_decision_tag = raw.strip().startswith("<decision")
    parser = parse_decision_strict_xml if starts_with_decision_tag else parse_decision_soft
    return parser(raw)