|
|
| import torch, torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
| import pydicom |
| from pathlib import Path |
| from transformers import AutoModel, AutoProcessor |
|
|
# Chosen once at import time; every model and tensor batch in this module is moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
def _load_dicom(path: str) -> Image.Image:
    """Read a DICOM file and return its pixel data as an 8-bit RGB PIL image.

    The raw pixel array is min-max normalized to [0, 255] before conversion.
    NOTE(review): assumes a single-frame, grayscale pixel array — multi-frame
    series or photometric interpretations like MONOCHROME1 are not handled;
    confirm against the expected inputs.
    """
    dataset = pydicom.dcmread(path)
    pixels = dataset.pixel_array.astype(np.float32)
    pixels = pixels - pixels.min()
    peak = pixels.max()
    if peak > 0:  # guard against a constant image (division by zero)
        pixels = pixels / peak
    gray = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    # Models downstream expect 3-channel input, so replicate the gray channel.
    return Image.fromarray(gray, mode="L").convert("RGB")
|
|
def _load_image(path: str) -> Image.Image:
    """Open *path* as an RGB PIL image, routing ``.dcm`` files to the DICOM loader."""
    if str(path).lower().endswith(".dcm"):
        return _load_dicom(path)
    return Image.open(path).convert("RGB")
|
|
@torch.no_grad()
def _extract_features(encoder, enc_proc, image: Image.Image):
    """Encode *image* with *encoder* and return an L2-normalized embedding.

    Handles several encoder output conventions, in order of preference:
    CLIP-style ``image_embeds``, a mean-pooled ``last_hidden_state``, and —
    as a last resort — the mean of every tensor in the output mapping.
    NOTE(review): the fallback `torch.stack` assumes all tensor outputs share
    a shape; verify for the encoders actually shipped with this pipeline.
    """
    batch = enc_proc(images=image, text=["medical x-ray"], return_tensors="pt").to(DEVICE)
    outputs = encoder(**batch)

    if hasattr(outputs, "image_embeds"):
        return F.normalize(outputs.image_embeds, dim=-1)
    if hasattr(outputs, "last_hidden_state"):
        return F.normalize(outputs.last_hidden_state.mean(dim=1), dim=-1)

    candidates = [t for t in outputs.values() if isinstance(t, torch.Tensor)]
    return F.normalize(torch.stack(candidates).mean(dim=0), dim=-1)
|
|
| def _emb_to_prompt(vec: torch.Tensor) -> str: |
| digest = vec[0, :8].detach().cpu().numpy() |
| summary = ", ".join([f"{v:+.2f}" for v in digest]) |
| return ( |
| "Radiology image features summarized by vector signature " |
| f"[{summary}]. Generate a concise, structured radiology report (Findings, Impression)." |
| ) |
|
|
class MedFusionPipeline:
    """
    Unified pipeline with two modes:
      - mode='pro'  -> FP16 + pruned models (higher accuracy)
      - mode='lite' -> INT8 models (faster/smaller)

    Usage:
        pipe = MedFusionPipeline.from_pretrained(".", mode="pro")
        report = pipe.analyze("/path/to/xray_or_dicom")
    """

    # Accepted modes; each maps to a same-named subdirectory under the root.
    _MODES = ("pro", "lite")

    def __init__(self, root_dir: str, mode: str = "pro"):
        """Create a pipeline rooted at *root_dir* and load the models for *mode*.

        Raises:
            ValueError: if *mode* is not 'pro' or 'lite'.
        """
        self.root = Path(root_dir)
        self.set_mode(mode)

    @classmethod
    def from_pretrained(cls, root_dir: str, mode: str = "pro"):
        """Alternate constructor mirroring the ``transformers`` naming convention."""
        return cls(root_dir, mode)

    def set_mode(self, mode: str):
        """(Re)load the encoder/decoder pair for *mode*.

        Raises:
            ValueError: if *mode* is invalid. (Previously an ``assert``, which
            is stripped under ``python -O``; an explicit raise always runs.)
        """
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.mode = mode
        # The validated mode doubles as the checkpoint subdirectory name.
        enc_path = self.root / mode / "encoder"
        dec_path = self.root / mode / "decoder"

        self.encoder = AutoModel.from_pretrained(enc_path, trust_remote_code=True).to(DEVICE).eval()
        self.enc_proc = AutoProcessor.from_pretrained(enc_path)
        self.decoder = AutoModel.from_pretrained(dec_path, trust_remote_code=True).to(DEVICE).eval()
        self.dec_proc = AutoProcessor.from_pretrained(dec_path)

    @torch.no_grad()
    def analyze(self, path: str, max_new_tokens: int = 256):
        """Generate a radiology report for the image or DICOM file at *path*.

        Parameters:
            path: filesystem path to a ``.dcm`` file or any PIL-readable image.
            max_new_tokens: generation budget passed to the decoder.

        Returns:
            The decoded report text (str).
        """
        img = _load_image(path)
        emb = _extract_features(self.encoder, self.enc_proc, img)
        prompt = _emb_to_prompt(emb)
        inputs = self.dec_proc(images=img, text=[prompt], return_tensors="pt").to(DEVICE)
        out = self.decoder.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.dec_proc.batch_decode(out, skip_special_tokens=True)[0]
|
|