"""ACE-LoRA: zero-shot medical image classification with BiomedCLIP + LoRA + ACE-HGAT.

Loads a BiomedCLIP backbone, injects LoRA adapters, and applies an
attention-guided hierarchical graph (ACE-HGAT) head on both encoders.
Run as a script to evaluate zero-shot pneumonia detection on the RSNA test split.
"""

import os

import pandas as pd
import pydicom
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image
from safetensors.torch import save_file, load_file
from torchmetrics.classification import BinaryAUROC, BinaryAccuracy
from tqdm import tqdm

from bert_modeling_bert_self_attn_patch import patch_bert_self_attn
from loralib.utils import apply_lora
from loss import CLIPLossACE_HGAT
from open_clip_patch import patch_encode_text
from prompt_templates import prompt_templates
from timm_vit_return_attn_patch import patch_timm_vit_return_attn_scores


def load_config_to_args(args_obj, config_dict):
    """Copy every key/value pair of *config_dict* onto *args_obj* as attributes.

    Returns *args_obj* for call-chaining convenience.
    """
    for key, value in config_dict.items():
        setattr(args_obj, key, value)
    return args_obj


class _Args:
    """Bare namespace used to present a config dict as attribute-style args."""
    pass


class ACE_LoRA_Model(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/icon-lab/ACE-LoRA",
    pipeline_tag="zero-shot-classification",
    license="mit",
):
    """BiomedCLIP backbone with LoRA adapters and an ACE-HGAT head.

    Only the LoRA weights, the ACE-HGAT adapters (inside ``loss_fn``) and the
    logit scale are serialized by ``_save_pretrained`` / restored by
    ``_from_pretrained``; the frozen CLIP backbone is re-downloaded from its
    original hub repo on every construction.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        base_model_name: str = config.get("base_model_name", "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
        feature_dim: int = config.get("feature_dim", 512)
        self.context_length: int = config.get("context_length", 256)
        self.clip_model, self.preprocess = create_model_from_pretrained(base_model_name)
        self.tokenizer = get_tokenizer(base_model_name)
        # Monkey-patch open_clip / timm / BERT so both encoders expose the
        # token-level features and attention scores the ACE-HGAT head needs.
        patch_encode_text()
        patch_timm_vit_return_attn_scores()
        patch_bert_self_attn()
        args = _Args()
        load_config_to_args(args, config)
        self.lora_layers = apply_lora(args, self.clip_model)
        # Register the LoRA tensors so they show up in .parameters()/state_dict().
        self.lora_params = nn.ParameterList([p for group in self.lora_layers for p in group.parameters()])
        logit_scale = self.clip_model.state_dict()["logit_scale"].exp()
        self.loss_fn = CLIPLossACE_HGAT(args, logit_scale, feature_dim)
        # NOTE(review): forward() multiplies by this raw (non-exp'd) scale while
        # the loss received the exp'd one; softmax ordering is unaffected, but
        # the two scales differ — confirm this is intentional.
        self.logit_scale = nn.Parameter(self.clip_model.state_dict()["logit_scale"].clone(), requires_grad=False)

    def _save_pretrained(self, save_directory: str):
        """Persist only the trainable pieces: LoRA weights, ACE-HGAT head, logit scale."""
        os.makedirs(save_directory, exist_ok=True)
        payload = {
            **{k: v for k, v in self.clip_model.state_dict().items() if "lora" in k.lower()},
            **{f"loss_fn.{k}": v for k, v in self.loss_fn.state_dict().items()},
            "logit_scale": self.logit_scale.data,
        }
        # safetensors refuses non-contiguous tensors.
        payload = {k: v.contiguous() for k, v in payload.items()}
        save_file(payload, os.path.join(save_directory, "model.safetensors"))

    @classmethod
    def _from_pretrained(cls, *, model_id, revision=None, cache_dir=None, force_download=False, proxies=None, resume_download=False, local_files_only=False, token=None, map_location="cpu", strict=False, config=None, **kwargs):
        """Rebuild the model from *config*, then overlay the saved LoRA/ACE-HGAT weights.

        *model_id* may be a local directory containing ``model.safetensors`` or a
        hub repo id, in which case the checkpoint is downloaded.
        """
        model = cls(config=config or {})
        local_ckpt = os.path.join(model_id, "model.safetensors")
        if os.path.isfile(local_ckpt):
            ckpt_path = local_ckpt
        else:
            from huggingface_hub import hf_hub_download
            ckpt_path = hf_hub_download(
                repo_id=model_id,
                filename="model.safetensors",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
            )
        state = load_file(ckpt_path, device=map_location)
        # Splice the saved LoRA tensors into the freshly built CLIP state dict.
        lora_state = {k: v for k, v in state.items() if "lora" in k.lower()}
        clip_sd = model.clip_model.state_dict()
        clip_sd.update(lora_state)
        model.clip_model.load_state_dict(clip_sd, strict=True)
        # Re-collect the parameter list so it points at the loaded tensors.
        model.lora_params = nn.ParameterList([p for group in model.lora_layers for p in group.parameters()])
        ace_state = {k.replace("loss_fn.", ""): v for k, v in state.items() if k.startswith("loss_fn.")}
        model.loss_fn.load_state_dict(ace_state, strict=True)
        if "logit_scale" in state:
            model.logit_scale.data.copy_(state["logit_scale"])
            model.loss_fn.logit_scale.data.copy_(state["logit_scale"])
        return model

    @staticmethod
    def _apply_ace_hgat(loss_fn, features, attn_weights, encoder="img"):
        """Run the ACE-HGAT graph refinement over token *features*.

        Args:
            loss_fn: module holding the per-encoder edge/node adapters.
            features: (B, N, D) token features; index 0 is the CLS token.
            attn_weights: (B, N-1) CLS→patch attention used as CLS edge weights.
            encoder: ``"img"`` or ``"text"`` — selects which adapter pair to use.

        Returns:
            (B, N, D) context-refined features.

        Raises:
            ValueError: if *encoder* is neither ``"img"`` nor ``"text"``.
        """
        if encoder == "img":
            edge_adapter = loss_fn.img_edge_adapter
            node_adapter = loss_fn.img_node_adapter
        elif encoder == "text":
            edge_adapter = loss_fn.text_edge_adapter
            node_adapter = loss_fn.text_node_adapter
        else:
            raise ValueError(f"encoder must be 'img' or 'text', got {encoder!r}")
        B, N, D = features.shape
        # Patch-patch similarity (cosine, since rows are L2-normalized).
        patches_norm = F.normalize(features[:, 1:, :], p=2, dim=-1)
        sim = torch.zeros(B, N, N, device=features.device)
        patch_sim = torch.bmm(patches_norm, patches_norm.transpose(1, 2))
        sim[:, 1:, 1:] = patch_sim
        # CLS row uses the encoder's attention instead of feature similarity.
        sim[:, 0, 1:] = attn_weights
        eye = torch.eye(N, device=features.device).bool().unsqueeze(0).repeat(B, 1, 1)
        # Mask self-edges and patch→CLS edges before the softmax.
        mask = eye.clone()
        mask[:, 1:, 0] = True
        sim = sim.masked_fill(mask, float("-inf"))
        # Keep only each node's top-5 neighbours; everything else gets -inf.
        topk_vals, topk_idx = torch.topk(sim, k=5, dim=-1)
        sparse = torch.full_like(sim, float("-inf"))
        sparse.scatter_(-1, topk_idx, topk_vals)
        A = F.softmax(sparse, dim=-1)
        # Restore self-loops and mirror CLS→patch weights onto patch→CLS.
        A = A.masked_fill(eye, 1.0)
        A[:, 1:, 0] = A[:, 0, 1:]
        # Node → edge → context aggregation.
        H_edges = edge_adapter(torch.matmul(A, features))
        H_context = node_adapter(torch.matmul(A.transpose(1, 2), H_edges))
        return H_context

    @torch.no_grad()
    def encode_texts(self, class_names: list[str]) -> torch.Tensor:
        """Encode each class name (averaged over prompt templates) through ACE-HGAT.

        Returns a (num_classes, N, D) tensor of refined, L2-normalized token features.
        """
        device = self.logit_scale.device
        feats = []
        for name in class_names:
            tokens = self.tokenizer([t(name) for t in prompt_templates], context_length=self.context_length).to(device)
            feat, attn = self.clip_model.encode_text(tokens, normalize=True, output_attentions=True, output_tokens=True)
            feat = feat / feat.norm(dim=-1, keepdim=True)
            # Average the template ensemble, then the last-layer heads/templates
            # for the CLS→token attention row.
            feat = feat.mean(dim=0)
            attn_w = attn[-1].mean(dim=1).mean(dim=0, keepdim=True)[:, 0, 1:]
            feat = self._apply_ace_hgat(self.loss_fn, feat.unsqueeze(0), attn_w, encoder="text")
            feat = F.normalize(feat, dim=-1)
            feats.append(feat)
        return torch.cat(feats, dim=0)

    @torch.no_grad()
    def encode_image(self, pil_image: Image.Image) -> torch.Tensor:
        """Encode one PIL image through the ViT + ACE-HGAT; returns (1, N, D) features."""
        device = self.logit_scale.device
        # Temporarily disable global pooling so the trunk emits per-token features.
        old_pool = self.clip_model.visual.trunk.global_pool
        self.clip_model.visual.trunk.global_pool = ""
        img_features, attn = self.clip_model.visual.trunk.get_attn_scores(self.preprocess(pil_image).unsqueeze(0).to(device))
        img_features = F.normalize(self.clip_model.visual.head(img_features), dim=-1)
        # CLS→patch attention averaged over heads.
        attn_w = attn.mean(dim=1)[:, 0, 1:]
        img_features = self._apply_ace_hgat(self.loss_fn, img_features, attn_w, encoder="img")
        img_features = F.normalize(img_features, dim=-1)
        self.clip_model.visual.trunk.global_pool = old_pool
        return img_features

    def forward(
        self,
        image: Image.Image,
        class_names: list[str],
    ) -> torch.Tensor:
        """Return softmax class probabilities of *image* over *class_names*."""
        logit_scale = self.logit_scale
        text_feats = self.encode_texts(class_names)
        image_feats = self.encode_image(image)
        # Compare CLS tokens only.
        logits = (logit_scale * image_feats[:, 0] @ text_feats[:, 0].t())
        return logits.squeeze(0).softmax(dim=-1)


if __name__ == "__main__":
    model = ACE_LoRA_Model.from_pretrained("aydnarda/ACE-LoRA", force_download=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # FIX: AUROC must live on the same device as the logits/labels it receives
    # (the accuracy metric already did); mixed devices raise in torchmetrics.
    auc_metric = BinaryAUROC(thresholds=None).to(device)
    acc_metric = BinaryAccuracy().to(device)
    model = model.to(device)
    model.eval()

    TEST_CSV_PATH = './RSNA/test.csv'
    df = pd.read_csv(TEST_CSV_PATH)
    test_paths = df['Path'].tolist()
    classes = ['No Finding', 'pneumonia']

    logits_list = []
    label_list = []
    for index in tqdm(range(len(df))):
        img_path = test_paths[index]
        img_data = pydicom.dcmread(img_path).pixel_array
        image = Image.fromarray(img_data)
        # One-hot label from the integer 'Target' column (0 or 1).
        label = torch.zeros(len(classes), dtype=torch.int8, device=device)
        label[df['Target'][index]] = 1
        logits = model(image, classes).unsqueeze(0)
        logits_list.append(logits)
        label_list.append(label.argmax())

    logits_all = torch.cat(logits_list, dim=0)  # (N, C)
    labels_all = torch.stack(label_list)
    # Binary metrics score the probability of the positive class (index 1).
    auc = auc_metric(logits_all[:, 1], labels_all)
    acc = acc_metric(logits_all[:, 1], labels_all)
    print("ACC: ", acc)
    print("AUC: ", auc)