# MedFusion-AI — medfusion_pipeline.py
# Commit b85ba72 (verified): "Add unified MedFusion-AI (Pro+Lite) with pipeline & app"
import torch, torch.nn.functional as F
import numpy as np
from PIL import Image
import pydicom
from pathlib import Path
from transformers import AutoModel, AutoProcessor
# Run on GPU when available; all models and input tensors are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _load_dicom(path: str) -> Image.Image:
    """Read a DICOM file and return its pixels as an 8-bit RGB PIL image.

    The raw pixel array is mapped through the modality rescale transform
    (RescaleSlope/RescaleIntercept, when present), MONOCHROME1 images are
    inverted so higher values always render brighter, and the result is
    min-max normalized to the [0, 255] range.

    Args:
        path: Filesystem path to a .dcm file.

    Returns:
        A PIL RGB image (grayscale content replicated across 3 channels).
    """
    ds = pydicom.dcmread(path)
    arr = ds.pixel_array.astype(np.float32)
    # Apply the modality rescale transform when the tags are present
    # (maps stored values to real-world units, e.g. Hounsfield for CT).
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    arr = arr * slope + intercept
    # MONOCHROME1 renders the minimum value as white; invert so high = bright.
    if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
        arr = arr.max() - arr
    arr -= arr.min()
    peak = arr.max()
    if peak > 0:  # guard against constant (all-equal) images
        arr /= peak
    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(arr, mode="L").convert("RGB")
def _load_image(path: str) -> Image.Image:
    """Open *path* as an RGB PIL image.

    Files with a .dcm extension (case-insensitive) are routed through the
    dedicated DICOM loader; everything else is opened directly with PIL.
    """
    if str(path).lower().endswith(".dcm"):
        return _load_dicom(path)
    return Image.open(path).convert("RGB")
@torch.no_grad()
def _extract_features(encoder, enc_proc, image: Image.Image):
    """Run the vision encoder on *image* and return an L2-normalized embedding.

    Handles several encoder output conventions, in priority order:
    a pooled ``image_embeds`` attribute, a token-level ``last_hidden_state``
    (mean-pooled over the sequence axis), or — as a last resort — the mean
    of every tensor found in the output mapping.
    """
    batch = enc_proc(images=image, text=["medical x-ray"], return_tensors="pt").to(DEVICE)
    outputs = encoder(**batch)
    if hasattr(outputs, "image_embeds"):
        return F.normalize(outputs.image_embeds, dim=-1)
    if hasattr(outputs, "last_hidden_state"):
        return F.normalize(outputs.last_hidden_state.mean(dim=1), dim=-1)
    # Fallback: average whatever tensors the model produced.
    tensor_values = [t for t in outputs.values() if isinstance(t, torch.Tensor)]
    pooled = torch.stack(tensor_values).mean(dim=0)
    return F.normalize(pooled, dim=-1)
def _emb_to_prompt(vec: torch.Tensor) -> str:
digest = vec[0, :8].detach().cpu().numpy()
summary = ", ".join([f"{v:+.2f}" for v in digest])
return (
"Radiology image features summarized by vector signature "
f"[{summary}]. Generate a concise, structured radiology report (Findings, Impression)."
)
class MedFusionPipeline:
    """
    Unified pipeline with two modes:
      - mode='pro'  -> FP16 + pruned models (higher accuracy)
      - mode='lite' -> INT8 models (faster/smaller)
    Usage:
        pipe = MedFusionPipeline.from_pretrained(".", mode="pro")
        report = pipe.analyze("/path/to/xray_or_dicom")
    """
    def __init__(self, root_dir: str, mode: str = "pro"):
        """Resolve *root_dir* and eagerly load the model pair for *mode*."""
        self.root = Path(root_dir)
        self.set_mode(mode)

    @classmethod
    def from_pretrained(cls, root_dir: str, mode: str = "pro"):
        """Alternate constructor mirroring the transformers naming convention."""
        return cls(root_dir, mode)

    def set_mode(self, mode: str):
        """Switch between the 'pro' and 'lite' model variants.

        Eagerly (re)loads the encoder/decoder pair from
        ``<root>/<mode>/encoder`` and ``<root>/<mode>/decoder``.
        A no-op when *mode* is already active, so repeated calls are cheap.

        Raises:
            ValueError: if *mode* is not 'pro' or 'lite'.
        """
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if mode not in ("pro", "lite"):
            raise ValueError("mode must be 'pro' or 'lite'")
        # Skip the expensive reload when this mode's models are already up.
        if getattr(self, "mode", None) == mode and getattr(self, "encoder", None) is not None:
            return
        enc_path = self.root / mode / "encoder"
        dec_path = self.root / mode / "decoder"
        # Eager load (not lazy): models are pulled onto DEVICE immediately.
        self.encoder = AutoModel.from_pretrained(enc_path, trust_remote_code=True).to(DEVICE).eval()
        self.enc_proc = AutoProcessor.from_pretrained(enc_path)
        self.decoder = AutoModel.from_pretrained(dec_path, trust_remote_code=True).to(DEVICE).eval()
        self.dec_proc = AutoProcessor.from_pretrained(dec_path)
        # Record the active mode only after every component loaded cleanly.
        self.mode = mode

    @torch.no_grad()
    def analyze(self, path: str, max_new_tokens: int = 256):
        """Generate a text radiology report for an image or DICOM file.

        Args:
            path: Path to a regular image file or a .dcm DICOM file.
            max_new_tokens: Generation budget passed to the decoder.

        Returns:
            The decoded report string (special tokens stripped).
        """
        img = _load_image(path)
        emb = _extract_features(self.encoder, self.enc_proc, img)
        prompt = _emb_to_prompt(emb)
        inputs = self.dec_proc(images=img, text=[prompt], return_tensors="pt").to(DEVICE)
        out = self.decoder.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.dec_proc.batch_decode(out, skip_special_tokens=True)[0]