| import torch |
| import torch.nn as nn |
| from huggingface_hub import PyTorchModelHubMixin |
| from open_clip import create_model_from_pretrained, get_tokenizer |
| import os |
| from open_clip_patch import patch_encode_text |
| from timm_vit_return_attn_patch import patch_timm_vit_return_attn_scores |
| from bert_modeling_bert_self_attn_patch import patch_bert_self_attn |
| from loralib.utils import apply_lora |
| from loss import CLIPLossACE_HGAT |
| from PIL import Image |
| import torch.nn.functional as F |
| from prompt_templates import prompt_templates |
| from torchmetrics.classification import BinaryAUROC, BinaryAccuracy |
| import pandas as pd |
| from tqdm import tqdm |
| import pydicom |
| from safetensors.torch import save_file, load_file |
|
|
def load_config_to_args(args_obj, config_dict):
    """Copy every key/value pair of *config_dict* onto *args_obj* as attributes.

    Used to turn a plain config dict into an argparse-style namespace that
    downstream helpers can read via attribute access.

    Returns the same *args_obj* for call-chaining convenience.
    """
    for name in config_dict:
        setattr(args_obj, name, config_dict[name])
    return args_obj
|
|
class _Args:
    """Empty attribute container.

    Populated via load_config_to_args() to emulate an argparse.Namespace,
    then handed to apply_lora() and CLIPLossACE_HGAT, which expect
    attribute-style configuration access.
    """
    pass
| |
class ACE_LoRA_Model(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/icon-lab/ACE-LoRA",
    pipeline_tag="zero-shot-classification",
    license="mit",
):
    """BiomedCLIP backbone with LoRA adapters and ACE-HGAT graph adapters,
    packaged as a Hub-loadable zero-shot classifier.

    The constructor monkey-patches open_clip / the timm ViT / BERT so the
    text and vision encoders expose attention scores, then injects LoRA
    layers into the CLIP model. Only the LoRA weights, the ACE-HGAT
    adapters (owned by the loss module) and ``logit_scale`` are serialized
    by ``_save_pretrained`` and restored by ``_from_pretrained``.
    """

    def __init__(self, config: dict):
        super().__init__()

        self.config = config
        # Defaults target BiomedCLIP (ViT-B/16 vision tower + PubMedBERT text tower).
        base_model_name: str = config.get("base_model_name", "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
        feature_dim: int = config.get("feature_dim", 512)
        self.context_length: int = config.get("context_length", 256)

        self.clip_model, self.preprocess = create_model_from_pretrained(base_model_name)
        self.tokenizer = get_tokenizer(base_model_name)

        # Global monkey-patches: make encode_text / the timm ViT / BERT
        # self-attention return the attention scores that ACE-HGAT consumes.
        patch_encode_text()
        patch_timm_vit_return_attn_scores()
        patch_bert_self_attn()
        args = _Args()

        # apply_lora expects an argparse-style namespace; mirror the config dict onto it.
        load_config_to_args(args, config)
        self.lora_layers = apply_lora(args, self.clip_model)
        # Register the LoRA parameters so they show up in this module's parameters().
        self.lora_params = nn.ParameterList([p for group in self.lora_layers for p in group.parameters()])
        logit_scale = self.clip_model.state_dict()["logit_scale"].exp()
        # The loss module also owns the ACE-HGAT node/edge adapters used at inference.
        self.loss_fn = CLIPLossACE_HGAT(args, logit_scale, feature_dim)
        # Frozen copy of CLIP's stored (log) temperature. Note loss_fn above
        # received the exp()'d value, while this keeps the raw stored value.
        self.logit_scale = nn.Parameter(self.clip_model.state_dict()["logit_scale"].clone(), requires_grad=False)

    def _save_pretrained(self, save_directory: str):
        """Serialize only the trainable pieces — LoRA weights, the ACE-HGAT
        adapters (namespaced under ``loss_fn.``) and ``logit_scale`` — to
        ``model.safetensors`` inside *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        payload = {
            # LoRA tensors are identified purely by "lora" appearing in their key.
            **{k: v for k, v in self.clip_model.state_dict().items() if "lora" in k.lower()},
            # Prefix so _from_pretrained can route these back into loss_fn.
            **{f"loss_fn.{k}": v for k, v in self.loss_fn.state_dict().items()},
            "logit_scale": self.logit_scale.data,
        }

        # safetensors refuses non-contiguous tensors.
        payload = {k: v.contiguous() for k, v in payload.items()}
        save_file(payload, os.path.join(save_directory, "model.safetensors"))


    @classmethod
    def _from_pretrained(cls, *, model_id, revision=None, cache_dir=None,
                         force_download=False, proxies=None, resume_download=False,
                         local_files_only=False, token=None, map_location="cpu",
                         strict=False, config=None, **kwargs):
        """Rebuild the model and load the sparse checkpoint saved by
        ``_save_pretrained``.

        *model_id* may be a local directory containing ``model.safetensors``
        or a Hub repo id. ``strict`` and extra ``kwargs`` are accepted for
        mixin API compatibility but are not used here.
        """

        model = cls(config=config or {})


        # Prefer a local checkpoint next to model_id; otherwise pull from the Hub.
        local_ckpt = os.path.join(model_id, "model.safetensors")
        if os.path.isfile(local_ckpt):
            ckpt_path = local_ckpt
        else:
            from huggingface_hub import hf_hub_download
            ckpt_path = hf_hub_download(
                repo_id=model_id, filename="model.safetensors",
                revision=revision, cache_dir=cache_dir,
                force_download=force_download, proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only, token=token,
            )


        state = load_file(ckpt_path, device=map_location)
        # Overlay only the LoRA tensors on the freshly created CLIP state,
        # then load strictly so any key mismatch surfaces immediately.
        lora_state = {k: v for k, v in state.items() if "lora" in k.lower()}
        clip_sd = model.clip_model.state_dict()
        clip_sd.update(lora_state)
        model.clip_model.load_state_dict(clip_sd, strict=True)
        # Re-collect LoRA parameters so the ParameterList points at the loaded tensors.
        model.lora_params = nn.ParameterList([p for group in model.lora_layers for p in group.parameters()])


        # Strip the "loss_fn." namespace and restore the ACE-HGAT adapters.
        ace_state = {k.replace("loss_fn.", ""): v for k, v in state.items() if k.startswith("loss_fn.")}
        model.loss_fn.load_state_dict(ace_state, strict=True)


        if "logit_scale" in state:
            model.logit_scale.data.copy_(state["logit_scale"])
            # NOTE(review): this copies the raw stored value into
            # loss_fn.logit_scale although __init__ constructed loss_fn with
            # the exp()'d value — confirm this is intentional.
            model.loss_fn.logit_scale.data.copy_(state["logit_scale"])


        return model

    @staticmethod
    def _apply_ace_hgat(loss_fn, features, attn_weights, encoder="img"):
        """Run the ACE-HGAT adapter pair over token features.

        Builds a sparse (top-5 per row) token graph from patch-patch cosine
        similarity plus the CLS->patch attention weights, row-softmaxes it
        into an adjacency ``A``, then applies the edge adapter followed by
        the node adapter.

        Args:
            loss_fn: CLIPLossACE_HGAT module holding the adapters.
            features: token features of shape (B, N, D); index 0 is the CLS token.
            attn_weights: CLS->patch attention of shape (B, N-1) — assumed from
                the indexing below; confirm against the callers.
            encoder: "img" or "text" — selects which adapter pair to use.

        Raises:
            ValueError: if *encoder* is neither "img" nor "text".
        """
        if encoder == "img":
            edge_adapter = loss_fn.img_edge_adapter
            node_adapter = loss_fn.img_node_adapter
        elif encoder == "text":
            edge_adapter = loss_fn.text_edge_adapter
            node_adapter = loss_fn.text_node_adapter
        else:
            raise ValueError(f"encoder must be 'img' or 'text', got {encoder!r}")

        B, N, D = features.shape
        # Cosine similarity between non-CLS tokens only.
        patches_norm = F.normalize(features[:, 1:, :], p=2, dim=-1)
        sim = torch.zeros(B, N, N, device=features.device)
        patch_sim = torch.bmm(patches_norm, patches_norm.transpose(1, 2))
        sim[:, 1:, 1:] = patch_sim
        # The CLS row uses the encoder's CLS->patch attention instead of cosine sim.
        sim[:, 0, 1:] = attn_weights
        eye = torch.eye(N, device=features.device).bool().unsqueeze(0).repeat(B, 1, 1)
        mask = eye.clone()
        # Also mask the patch->CLS column so edges back to CLS never win top-k.
        mask[:, 1:, 0] = True
        sim = sim.masked_fill(mask, float("-inf"))

        # Keep only the 5 strongest neighbours per token; softmax over those.
        topk_vals, topk_idx = torch.topk(sim, k=5, dim=-1)
        sparse = torch.full_like(sim, float("-inf"))
        sparse.scatter_(-1, topk_idx, topk_vals)
        A = F.softmax(sparse, dim=-1)
        # Self-loops get weight 1; CLS<->patch edges are made symmetric.
        A = A.masked_fill(eye, 1.0)
        A[:, 1:, 0] = A[:, 0, 1:]
        H_edges = edge_adapter(torch.matmul(A, features))
        H_context = node_adapter(torch.matmul(A.transpose(1, 2), H_edges))
        return H_context

    @torch.no_grad()
    def encode_texts(self, class_names: list[str]) -> torch.Tensor:
        """Encode each class name through the prompt-template ensemble and the
        text-side ACE-HGAT adapter; returns one normalized row per class,
        concatenated along dim 0."""
        device = self.logit_scale.device
        feats = []

        for name in class_names:
            # One tokenized prompt per template; features are averaged over templates.
            tokens = self.tokenizer([t(name) for t in prompt_templates], context_length=self.context_length).to(device)
            # Patched encode_text returns per-token features plus attentions;
            # presumably feat is (templates, tokens, dim) — TODO confirm.
            feat, attn = self.clip_model.encode_text(tokens, normalize=True, output_attentions=True, output_tokens=True)
            feat = feat / feat.norm(dim=-1, keepdim=True)
            feat = feat.mean(dim=0)

            # Last layer's attention, averaged over heads then templates; CLS row only.
            attn_w = attn[-1].mean(dim=1).mean(dim=0, keepdim=True)[:, 0, 1:]
            feat = self._apply_ace_hgat(self.loss_fn, feat.unsqueeze(0), attn_w, encoder="text")
            feat = F.normalize(feat, dim=-1)
            feats.append(feat)

        return torch.cat(feats, dim=0)

    @torch.no_grad()
    def encode_image(self, pil_image: Image.Image) -> torch.Tensor:
        """Encode a PIL image and return ACE-HGAT-adapted, L2-normalized
        per-token features (batch of 1)."""
        device = self.logit_scale.device
        # Temporarily disable global pooling so the trunk yields per-token features.
        old_pool = self.clip_model.visual.trunk.global_pool
        self.clip_model.visual.trunk.global_pool = ""

        # Patched trunk returns features plus attention scores.
        img_features, attn = self.clip_model.visual.trunk.get_attn_scores(self.preprocess(pil_image).unsqueeze(0).to(device))
        img_features = F.normalize(self.clip_model.visual.head(img_features), dim=-1)
        # Head-averaged CLS->patch attention row.
        attn_w = attn.mean(dim=1)[:, 0, 1:]
        img_features = self._apply_ace_hgat(self.loss_fn, img_features, attn_w, encoder="img")
        img_features = F.normalize(img_features, dim=-1)
        # Restore pooling so other users of the trunk are unaffected.
        self.clip_model.visual.trunk.global_pool = old_pool
        return img_features

    def forward(
        self,
        image: Image.Image,
        class_names: list[str],
    ) -> torch.Tensor:
        """Zero-shot classify *image* against *class_names*; returns a 1-D
        tensor of softmax probabilities, one per class name."""
        # NOTE(review): this is the raw stored logit_scale (a log temperature
        # in standard CLIP), not its exp() — confirm the scaling is intended.
        logit_scale = self.logit_scale
        text_feats = self.encode_texts(class_names)
        image_feats = self.encode_image(image)

        # Compare CLS-token (index 0) features only.
        logits = (logit_scale * image_feats[:, 0] @ text_feats[:, 0].t())
        return logits.squeeze(0).softmax(dim=-1)
| |
if __name__ == "__main__":
    # Zero-shot evaluation of the pretrained ACE-LoRA model on the RSNA
    # pneumonia test split (binary: 'No Finding' vs 'pneumonia').
    model = ACE_LoRA_Model.from_pretrained("aydnarda/ACE-LoRA", force_download=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Fix: keep BOTH metrics on the evaluation device — previously only
    # acc_metric was moved, so auc_metric would fail on CUDA inputs.
    auc_metric = BinaryAUROC(thresholds=None).to(device)
    acc_metric = BinaryAccuracy().to(device)
    model = model.to(device)
    model.eval()

    TEST_CSV_PATH = './RSNA/test.csv'
    df = pd.read_csv(TEST_CSV_PATH)
    test_paths = df['Path'].tolist()
    targets = df['Target'].tolist()
    classes = ['No Finding', 'pneumonia']
    logits_list = []
    label_list = []


    for index in tqdm(range(len(df))):
        # Read the DICOM pixel data and wrap it as a PIL image for preprocessing.
        img_data = pydicom.dcmread(test_paths[index]).pixel_array
        image = Image.fromarray(img_data)

        # 'Target' is already the class index (0 = No Finding, 1 = pneumonia),
        # so store it directly instead of the old one-hot -> argmax round trip;
        # the unused `pred` tensor is dropped.
        label_list.append(torch.tensor(targets[index], device=device))
        logits = model(image, classes).unsqueeze(0)
        logits_list.append(logits)


    logits_all = torch.cat(logits_list, dim=0)
    labels_all = torch.stack(label_list)
    # Score using the probability of the positive class (pneumonia).
    auc = auc_metric(logits_all[:, 1], labels_all)
    acc = acc_metric(logits_all[:, 1], labels_all)


    print("ACC: ", acc)
    print("AUC: ", auc)