File size: 8,942 Bytes
21ccfaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b081d1
21ccfaf
 
 
 
 
 
 
1b5f5d8
 
 
21ccfaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b081d1
21ccfaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b5f5d8
 
 
21ccfaf
 
 
 
 
 
1b5f5d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21ccfaf
1b5f5d8
 
 
 
 
 
21ccfaf
1b5f5d8
 
 
 
7b081d1
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""WoundNetB7 End-to-End Pipeline: Image -> Segmentation -> PWAT + Fitzpatrick/ITA.

Usage:
    from pipeline import WoundNetB7Pipeline
    pipe = WoundNetB7Pipeline("models/")
    result = pipe.analyze("path/to/dfu_image.png")
    print(result)
"""
import torch
import numpy as np
import cv2
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional

from src.segmentation import load_segmentation_model, segment, CLASS_NAMES, CLASS_COLORS
from src.fitzpatrick_estimator import estimate_fitzpatrick, FitzpatrickResult
from src.pwat_estimator import PWATPredictor, PWATResult, ITEM_NAMES
from src.integrated_report import render_integrated_report


@dataclass
class AnalysisResult:
    """Complete DFU analysis result from a single image.

    Bundles the segmentation output, the Fitzpatrick/ITA estimate and the
    PWAT scores, plus helpers to render them as text (`summary`) or as a
    plain dictionary (`to_dict`).
    """
    # Segmentation
    class_distribution: dict = field(default_factory=dict)  # {class_name: percentage}
    classmap: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8
    probs: Optional[np.ndarray] = field(default=None, repr=False)     # (4,H,W) float32
    ulcer_mask: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8
    # Fitzpatrick
    fitzpatrick: Optional[FitzpatrickResult] = None
    # PWAT
    pwat: Optional[PWATResult] = None
    # Metadata
    image_size: tuple = (0, 0)  # (height, width); summary() prints it as WxH
    device: str = "cpu"

    def summary(self) -> str:
        """Return a multi-line, human-readable report of the analysis.

        The segmentation section is always present. The Fitzpatrick section is
        emitted only when `fitzpatrick` is set, and the PWAT table only when
        `pwat` is set and has raw scores. Items missing from the PWAT score
        dictionaries are shown as "-".
        """
        lines = ["=" * 50, "WoundNetB7 DFU Analysis", "=" * 50]
        lines.append(f"Image: {self.image_size[1]}x{self.image_size[0]}")
        lines.append(f"Device: {self.device}")

        lines.append("\n--- Segmentation ---")
        for cls, pct in self.class_distribution.items():
            lines.append(f"  {cls:<15s}: {pct:5.1f}%")

        if self.fitzpatrick:
            f = self.fitzpatrick
            lines.append("\n--- Fitzpatrick / ITA ---")
            lines.append(f"  Type:       {f.fitzpatrick_type} ({f.fitzpatrick_label})")
            lines.append(f"  ITA:        {f.ita_angle:.1f} +/- {f.ita_std:.1f}")
            lines.append(f"  L* mean:    {f.l_skin_mean:.1f}")
            lines.append(f"  Confidence: {f.confidence:.2f}")
            lines.append(f"  Pixels:     {f.healthy_pixels:,}")

        if self.pwat and self.pwat.scores_raw:
            p = self.pwat
            # Tolerate a missing/empty adjusted-score mapping instead of
            # failing on `.get`.
            adjusted_scores = p.scores_adjusted or {}
            lines.append("\n--- PWAT Scores (Items 3-8) ---")
            lines.append(f"  {'Item':<22s} {'Raw':>4s} {'Adj':>5s}")
            lines.append("  " + "-" * 33)
            for item in [3, 4, 5, 6, 7, 8]:
                name = ITEM_NAMES.get(item, f"Item {item}")
                raw = p.scores_raw.get(item, "-")
                adj = adjusted_scores.get(item)
                # The "f" format code is only valid for numbers; applying it
                # to the "-" placeholder string would raise ValueError, so the
                # placeholder is formatted separately at the same width.
                adj_text = f"{'-':>5s}" if adj is None else f"{adj:>5.1f}"
                lines.append(f"  {name:<22s} {raw:>4}  {adj_text}")
            lines.append(f"  {'TOTAL':<22s} {p.total_raw:>4}  {p.total_adjusted:>5.1f}")
            lines.append(f"  Fitzpatrick correction applied for type: {p.fitzpatrick_type}")

        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return the scalar/dict parts of the result as a plain dictionary.

        The array fields (`classmap`, `probs`, `ulcer_mask`) are not included.
        The "fitzpatrick" and "pwat" keys are present only when the
        corresponding field is set.
        """
        d = {"image_size": self.image_size, "device": self.device, "class_distribution": self.class_distribution}
        if self.fitzpatrick:
            d["fitzpatrick"] = asdict(self.fitzpatrick)
        if self.pwat:
            d["pwat"] = {
                "scores_raw": self.pwat.scores_raw,
                "scores_adjusted": self.pwat.scores_adjusted,
                "total_raw": self.pwat.total_raw,
                "total_adjusted": self.pwat.total_adjusted,
            }
        return d


class WoundNetB7Pipeline:
    """End-to-end DFU analysis pipeline.

    Loads the segmentation model and the PWAT predictor once at construction
    time, then runs segmentation, Fitzpatrick estimation and PWAT prediction
    for each image passed to `analyze`.

    Args:
        models_dir: Path to models/ directory containing:
            - segmentation/WoundNetB7_proposed_best.pt
            - pwat/xgb_pwat{3-8}.pkl
        device: "cuda" or "cpu" (auto-detected if None)
        use_tta: Use test-time augmentation for segmentation (slower but more accurate)
    """

    def __init__(self, models_dir: str = "models", device: Optional[str] = None, use_tta: bool = True):
        self.models_dir = Path(models_dir)
        # Fall back to CUDA when available, otherwise CPU, if no device is given.
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.use_tta = use_tta

        # Load segmentation model
        seg_path = self.models_dir / "segmentation" / "WoundNetB7_proposed_best.pt"
        self.seg_model = load_segmentation_model(str(seg_path), self.device)
        print(f"Segmentation model loaded ({sum(p.numel() for p in self.seg_model.parameters()) / 1e6:.1f}M params)")

        # Load PWAT predictor
        pwat_path = self.models_dir / "pwat"
        self.pwat_predictor = PWATPredictor(str(pwat_path))
        print(f"PWAT models loaded ({len(self.pwat_predictor.models)} items)")

        print(f"Pipeline ready on {self.device}")

    def analyze(self, image_input, use_tta: Optional[bool] = None) -> AnalysisResult:
        """Analyze a DFU image end-to-end.

        Args:
            image_input: file path (str/Path) or an (H, W, C) numpy array.
                A 3-channel array is used as-is and treated as BGR; it is NOT
                converted, so RGB callers must convert beforehand. An array
                with any other channel count is passed through
                ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)``.
            use_tta: Override TTA setting for this call

        Returns:
            AnalysisResult with segmentation, Fitzpatrick, and PWAT data.

        Raises:
            FileNotFoundError: the path could not be read by ``cv2.imread``.
            ValueError: the array is not 3-dimensional (e.g. a 2-D grayscale
                image).
            TypeError: ``image_input`` is neither a path nor a numpy array.
        """
        tta = use_tta if use_tta is not None else self.use_tta

        # Load image
        if isinstance(image_input, (str, Path)):
            img_bgr = cv2.imread(str(image_input))
            if img_bgr is None:
                raise FileNotFoundError(f"Cannot read image: {image_input}")
        elif isinstance(image_input, np.ndarray):
            # Check the rank up front: indexing shape[2] on a 2-D array would
            # otherwise fail with an uninformative IndexError.
            if image_input.ndim != 3:
                raise ValueError(
                    f"Expected an (H, W, C) image array, got shape {image_input.shape}"
                )
            if image_input.shape[2] == 3:
                img_bgr = image_input
            else:
                # NOTE(review): presumably meant for 4-channel RGBA input --
                # confirm which channel layouts callers actually pass.
                img_bgr = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        else:
            raise TypeError(f"Unsupported input type: {type(image_input)}")

        h, w = img_bgr.shape[:2]

        # Step 1: Segmentation
        seg = segment(self.seg_model, img_bgr, self.device, use_tta=tta)
        classmap = seg["classmap"]

        # Share of pixels assigned to each class, as a percentage rounded to
        # one decimal place.
        class_dist = {
            name: round(float(np.mean(classmap == cid) * 100), 1)
            for cid, name in CLASS_NAMES.items()
        }

        # Step 2: Fitzpatrick estimation (from healthy skin)
        fitz = estimate_fitzpatrick(img_bgr, seg["masks"])

        # Step 3: PWAT prediction (from ulcer mask)
        # Class id 3 is used as the ulcer class; the mask is 0/255 uint8.
        ulcer_mask = (classmap == 3).astype(np.uint8) * 255
        pwat = self.pwat_predictor.predict(img_bgr, ulcer_mask, fitzpatrick_type=fitz.fitzpatrick_type)

        return AnalysisResult(
            class_distribution=class_dist,
            classmap=classmap,
            probs=seg["probs"],
            ulcer_mask=ulcer_mask,
            fitzpatrick=fitz,
            pwat=pwat,
            image_size=(h, w),
            device=str(self.device),
        )

    def visualize_binary(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create binary ulcer segmentation overlay (ulcer only, red mask).

        Returns an RGB uint8 image. Ulcer pixels (mask value > 127) are
        blended 40% image / 60% red and the external contours of the mask are
        outlined in white. When there is no mask, or the mask is empty, the
        plain RGB image is returned.
        """
        # astype() already returns a new array, so the caller's image is
        # never modified by the in-place blending below.
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)

        ulcer_bool = None if result.ulcer_mask is None else result.ulcer_mask > 127
        if ulcer_bool is None or not np.any(ulcer_bool):
            return np.clip(overlay, 0, 255).astype(np.uint8)

        red = np.array([255, 0, 0], dtype=np.float32)
        overlay[ulcer_bool] = overlay[ulcer_bool] * 0.4 + red * 0.6
        overlay_u8 = np.clip(overlay, 0, 255).astype(np.uint8)

        contours, _ = cv2.findContours(
            ulcer_bool.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(overlay_u8, contours, -1, (255, 255, 255), 2)
        return overlay_u8

    def visualize_multiclass(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create multiclass segmentation overlay using cached classmap.

        Returns an RGB uint8 image in which every pixel of a class other than
        id 0 is blended 50/50 with that class's entry in CLASS_COLORS. When
        `result.classmap` is None the plain RGB image is returned.
        """
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        if result.classmap is not None:
            for cid, color in CLASS_COLORS.items():
                if cid == 0:
                    # Class id 0 is left untinted (presumably background).
                    continue
                mask = result.classmap == cid
                overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
        # Clip before the cast, as visualize_binary does, so out-of-range
        # values can never wrap around in uint8.
        return np.clip(overlay, 0, 255).astype(np.uint8)

    def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create overlay visualization (backward compatible).

        Alias for `visualize_multiclass`.
        """
        return self.visualize_multiclass(img_bgr, result)

    def render_integrated_report(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Render a single-image integrated clinical dashboard (1920x1200 RGB).

        Combines segmentation, class distribution, Fitzpatrick/ITA estimation
        and PWAT scoring (raw + adjusted) into one nurse-facing report.
        """
        original_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        binary_overlay = self.visualize_binary(img_bgr, result)
        multi_overlay = self.visualize_multiclass(img_bgr, result)
        # The bare name resolves to the module-level function imported from
        # src.integrated_report, not to this method: class attributes are not
        # in scope inside a method body.
        return render_integrated_report(original_rgb, binary_overlay, multi_overlay, result)