"""MedFusion: encode a radiograph, summarize its embedding, and generate a report.

Pipeline: load an image (DICOM or standard format) -> extract a normalized
feature vector with a vision encoder -> fold a short signature of that vector
into a text prompt -> run a decoder model to produce the report text.
"""

from pathlib import Path

import numpy as np
import pydicom
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoProcessor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _load_dicom(path: str) -> Image.Image:
    """Read a DICOM file and return its pixels as an 8-bit RGB PIL image.

    The raw pixel array is min-max normalized to [0, 255].
    NOTE(review): assumes MONOCHROME2-style data; a MONOCHROME1 image would
    come out inverted — confirm against the expected inputs.
    """
    ds = pydicom.dcmread(path)
    arr = ds.pixel_array.astype(np.float32)
    arr -= arr.min()
    if arr.max() > 0:  # guard against constant images (avoid division by zero)
        arr /= arr.max()
    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr, mode="L").convert("RGB")


def _load_image(path: str) -> Image.Image:
    """Load *path* as an RGB PIL image, dispatching DICOM on the .dcm suffix."""
    p = str(path).lower()
    return _load_dicom(path) if p.endswith(".dcm") else Image.open(path).convert("RGB")


@torch.no_grad()
def _extract_features(encoder, enc_proc, image: Image.Image) -> torch.Tensor:
    """Run *image* through *encoder* and return an L2-normalized embedding.

    Handles three encoder output shapes:
    - CLIP-style outputs exposing ``image_embeds``;
    - plain transformer outputs, mean-pooled over ``last_hidden_state``;
    - otherwise, the mean of every tensor-valued output field.
      NOTE(review): this fallback requires all those tensors to share a shape
      (``torch.stack`` raises otherwise) — confirm for the encoders in use.
    """
    inputs = enc_proc(images=image, text=["medical x-ray"], return_tensors="pt").to(DEVICE)
    out = encoder(**inputs)
    if hasattr(out, "image_embeds"):
        emb = out.image_embeds
    elif hasattr(out, "last_hidden_state"):
        emb = out.last_hidden_state.mean(dim=1)
    else:
        tensors = [v for v in out.values() if isinstance(v, torch.Tensor)]
        emb = torch.stack(tensors).mean(dim=0)
    return F.normalize(emb, dim=-1)


def _emb_to_prompt(vec: torch.Tensor) -> str:
    """Build a decoder prompt embedding the first 8 components of *vec*.

    Expects a batched tensor of shape (1, D) with D >= 8 (fewer components
    are silently truncated by the slice).
    """
    digest = vec[0, :8].detach().cpu().numpy()
    summary = ", ".join([f"{v:+.2f}" for v in digest])
    return (
        "Radiology image features summarized by vector signature "
        f"[{summary}]. Generate a concise, structured radiology report (Findings, Impression)."
    )


class MedFusionPipeline:
    """
    Unified pipeline with two modes:
      - mode='pro'  -> FP16 + pruned models (higher accuracy)
      - mode='lite' -> INT8 models (faster/smaller)

    Usage:
        pipe = MedFusionPipeline.from_pretrained(".", mode="pro")
        report = pipe.analyze("/path/to/xray_or_dicom")
    """

    def __init__(self, root_dir: str, mode: str = "pro"):
        self.root = Path(root_dir)
        self.set_mode(mode)

    @classmethod
    def from_pretrained(cls, root_dir: str, mode: str = "pro"):
        """Alternate constructor mirroring the ``transformers`` naming convention."""
        return cls(root_dir, mode)

    def set_mode(self, mode: str) -> None:
        """Load the encoder/decoder pair for *mode* from ``<root>/<mode>/``.

        Loading is eager: both models are instantiated and moved to DEVICE
        immediately (the original "lazy load" comment was inaccurate).

        Raises:
            ValueError: if *mode* is not 'pro' or 'lite'.  (Was ``assert`` in
            the original, which is stripped under ``python -O``.)
        """
        if mode not in ("pro", "lite"):
            raise ValueError("mode must be 'pro' or 'lite'")
        self.mode = mode
        # Subdirectory name always equals the validated mode string.
        enc_path = self.root / mode / "encoder"
        dec_path = self.root / mode / "decoder"
        self.encoder = AutoModel.from_pretrained(enc_path, trust_remote_code=True).to(DEVICE).eval()
        self.enc_proc = AutoProcessor.from_pretrained(enc_path)
        self.decoder = AutoModel.from_pretrained(dec_path, trust_remote_code=True).to(DEVICE).eval()
        self.dec_proc = AutoProcessor.from_pretrained(dec_path)

    @torch.no_grad()
    def analyze(self, path: str, max_new_tokens: int = 256) -> str:
        """Generate a text report for the image at *path*.

        Args:
            path: image file; ``.dcm`` is read via pydicom, anything else via PIL.
            max_new_tokens: generation budget passed to the decoder.

        Returns:
            The decoded report string (special tokens stripped).
        """
        img = _load_image(path)
        emb = _extract_features(self.encoder, self.enc_proc, img)
        prompt = _emb_to_prompt(emb)
        inputs = self.dec_proc(images=img, text=[prompt], return_tensors="pt").to(DEVICE)
        out = self.decoder.generate(**inputs, max_new_tokens=max_new_tokens)
        txt = self.dec_proc.batch_decode(out, skip_special_tokens=True)[0]
        return txt