aydnarda commited on
Commit
1cc0e74
·
verified ·
1 Parent(s): 05a82cf

Upload hf_model_inference.py

Browse files
Files changed (1) hide show
  1. hf_model_inference.py +217 -0
hf_model_inference.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+ from open_clip import create_model_from_pretrained, get_tokenizer
5
+ import os
6
+ from open_clip_patch import patch_encode_text
7
+ from timm_vit_return_attn_patch import patch_timm_vit_return_attn_scores
8
+ from bert_modeling_bert_self_attn_patch import patch_bert_self_attn
9
+ from loralib.utils import apply_lora
10
+ from loss import CLIPLossACE_HGAT
11
+ from PIL import Image
12
+ import torch.nn.functional as F
13
+ from prompt_templates import prompt_templates
14
+ from torchmetrics.classification import BinaryAUROC, BinaryAccuracy
15
+ import pandas as pd
16
+ from tqdm import tqdm
17
+ import pydicom
18
+ from safetensors.torch import save_file, load_file
19
+
20
def load_config_to_args(args_obj, config_dict):
    """Copy every key/value pair from *config_dict* onto *args_obj* as attributes.

    Mimics an ``argparse.Namespace`` so downstream code written against CLI
    args can consume a plain config dict.  Returns *args_obj* for chaining.
    """
    for key in config_dict:
        setattr(args_obj, key, config_dict[key])
    return args_obj
25
+
26
+ class _Args:
27
+ pass
28
+
29
class ACE_LoRA_Model(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/icon-lab/ACE-LoRA",
    pipeline_tag="zero-shot-classification",
    license="mit",
):
    """BiomedCLIP backbone with LoRA adapters and an ACE-HGAT graph-attention
    head, packaged for the Hugging Face Hub via ``PyTorchModelHubMixin``.

    Only the LoRA weights, the ``loss_fn`` adapter weights, and ``logit_scale``
    are serialized; the frozen CLIP backbone is re-downloaded at load time.
    """

    def __init__(self, config: dict):
        """Build the CLIP backbone, apply in-place patches, and attach LoRA.

        Args:
            config: Flat dict of hyperparameters.  Recognized keys here are
                ``base_model_name``, ``feature_dim`` and ``context_length``;
                all keys are also forwarded to ``apply_lora`` / the loss head
                via an argparse-like namespace.
        """
        super().__init__()

        self.config = config
        base_model_name: str = config.get("base_model_name", "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
        feature_dim: int = config.get("feature_dim", 512)
        self.context_length: int = config.get("context_length", 256)

        self.clip_model, self.preprocess = create_model_from_pretrained(base_model_name)
        self.tokenizer = get_tokenizer(base_model_name)

        # Global monkey-patches: make encode_text return attentions/tokens and
        # expose attention scores from the timm ViT and BERT self-attention.
        # NOTE(review): these patch module-level code, so constructing one
        # instance affects every open_clip/timm/BERT model in the process.
        patch_encode_text()
        patch_timm_vit_return_attn_scores()
        patch_bert_self_attn()
        args = _Args()

        # Re-expose the whole config as attribute-style args for the helpers.
        load_config_to_args(args, config)
        self.lora_layers = apply_lora(args, self.clip_model)
        # Register LoRA parameters so they appear in this module's parameters().
        self.lora_params = nn.ParameterList([p for group in self.lora_layers for p in group.parameters()])
        logit_scale = self.clip_model.state_dict()["logit_scale"].exp()
        self.loss_fn = CLIPLossACE_HGAT(args, logit_scale, feature_dim)
        # Frozen copy of CLIP's (log-space) temperature; exp() applied at use.
        self.logit_scale = nn.Parameter(self.clip_model.state_dict()["logit_scale"].clone(), requires_grad=False)

    def _save_pretrained(self, save_directory: str):
        """Hub-mixin hook: write only the trainable deltas to ``model.safetensors``.

        Saved tensors: every CLIP state-dict entry whose key contains "lora",
        the full ``loss_fn`` state dict under a ``loss_fn.`` prefix, and
        ``logit_scale``.
        """
        os.makedirs(save_directory, exist_ok=True)
        payload = {
            **{k: v for k, v in self.clip_model.state_dict().items() if "lora" in k.lower()},
            **{f"loss_fn.{k}": v for k, v in self.loss_fn.state_dict().items()},
            "logit_scale": self.logit_scale.data,
        }

        # safetensors requires contiguous tensors.
        payload = {k: v.contiguous() for k, v in payload.items()}
        save_file(payload, os.path.join(save_directory, "model.safetensors"))

    @classmethod
    def _from_pretrained(cls, *, model_id, revision=None, cache_dir=None,
                         force_download=False, proxies=None, resume_download=False,
                         local_files_only=False, token=None, map_location="cpu",
                         strict=False, config=None, **kwargs):
        """Hub-mixin hook: rebuild the backbone, then overlay the saved deltas.

        ``model_id`` may be a local directory containing ``model.safetensors``
        or a Hub repo id.  NOTE(review): the ``strict`` argument is accepted
        but never used (loads below hard-code strict=True); ``resume_download``
        is deprecated in recent huggingface_hub releases — confirm against the
        pinned version.
        """

        model = cls(config=config or {})

        local_ckpt = os.path.join(model_id, "model.safetensors")
        if os.path.isfile(local_ckpt):
            ckpt_path = local_ckpt
        else:
            from huggingface_hub import hf_hub_download
            ckpt_path = hf_hub_download(
                repo_id=model_id, filename="model.safetensors",
                revision=revision, cache_dir=cache_dir,
                force_download=force_download, proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only, token=token,
            )

        state = load_file(ckpt_path, device=map_location)
        # Overlay the saved LoRA tensors on top of the freshly built backbone.
        lora_state = {k: v for k, v in state.items() if "lora" in k.lower()}
        clip_sd = model.clip_model.state_dict()
        clip_sd.update(lora_state)
        model.clip_model.load_state_dict(clip_sd, strict=True)
        # Re-register parameters so the ParameterList points at the loaded tensors.
        model.lora_params = nn.ParameterList([p for group in model.lora_layers for p in group.parameters()])

        # Restore the ACE-HGAT adapter weights (stored under a "loss_fn." prefix).
        ace_state = {k.replace("loss_fn.", ""): v for k, v in state.items() if k.startswith("loss_fn.")}
        model.loss_fn.load_state_dict(ace_state, strict=True)

        if "logit_scale" in state:
            model.logit_scale.data.copy_(state["logit_scale"])
            model.loss_fn.logit_scale.data.copy_(state["logit_scale"])

        return model

    @staticmethod
    def _apply_ace_hgat(loss_fn, features, attn_weights, encoder="img"):
        """Run the ACE-HGAT graph refinement over token features.

        Args:
            loss_fn: ``CLIPLossACE_HGAT`` instance providing the edge/node
                adapters for the chosen modality.
            features: (B, N, D) token features; index 0 is assumed to be the
                CLS token and indices 1: the patch/word tokens — TODO confirm.
            attn_weights: (B, N-1) CLS-to-token attention used as CLS edges.
            encoder: "img" or "text"; selects which adapter pair to use.

        Returns:
            (B, N, D) refined node features.
        """
        if encoder == "img":
            edge_adapter = loss_fn.img_edge_adapter
            node_adapter = loss_fn.img_node_adapter
        elif encoder == "text":
            edge_adapter = loss_fn.text_edge_adapter
            node_adapter = loss_fn.text_node_adapter
        else:
            raise ValueError(f"encoder must be 'img' or 'text', got {encoder!r}")

        B, N, D = features.shape
        # Pairwise cosine similarity between non-CLS tokens.
        patches_norm = F.normalize(features[:, 1:, :], p=2, dim=-1)
        sim = torch.zeros(B, N, N, device=features.device)
        patch_sim = torch.bmm(patches_norm, patches_norm.transpose(1, 2))
        sim[:, 1:, 1:] = patch_sim
        # CLS row uses the transformer's own CLS->token attention instead.
        sim[:, 0, 1:] = attn_weights
        eye = torch.eye(N, device=features.device).bool().unsqueeze(0).repeat(B, 1, 1)
        mask = eye.clone()
        # Mask self-edges and token->CLS edges out of the top-k selection.
        mask[:, 1:, 0] = True
        sim = sim.masked_fill(mask, float("-inf"))

        # Keep only the 5 strongest neighbors per node, softmax over them.
        topk_vals, topk_idx = torch.topk(sim, k=5, dim=-1)
        sparse = torch.full_like(sim, float("-inf"))
        sparse.scatter_(-1, topk_idx, topk_vals)
        A = F.softmax(sparse, dim=-1)
        # Re-insert self-loops with weight 1, and mirror CLS edges so the
        # token->CLS direction matches CLS->token.
        A = A.masked_fill(eye, 1.0)
        A[:, 1:, 0] = A[:, 0, 1:]
        # Hypergraph-style two-step aggregation: nodes -> edges -> nodes.
        H_edges = edge_adapter(torch.matmul(A, features))
        H_context = node_adapter(torch.matmul(A.transpose(1, 2), H_edges))
        return H_context

    @torch.no_grad()
    def encode_texts(self, class_names: list[str]) -> torch.Tensor:
        """Encode each class name through all prompt templates and the text
        ACE-HGAT head.

        Returns a tensor of refined text features, one (stacked) entry per
        class — shape depends on the adapters' output; downstream code reads
        index ``[:, 0]`` as the class embedding.
        """
        device = self.logit_scale.device
        feats = []

        for name in class_names:
            tokens = self.tokenizer([t(name) for t in prompt_templates], context_length=self.context_length).to(device)
            # Patched encode_text also returns per-layer attentions.
            feat, attn = self.clip_model.encode_text(tokens, normalize=True, output_attentions=True, output_tokens=True)
            # NOTE(review): features are re-normalized despite normalize=True —
            # redundant but harmless if already unit-norm; confirm intent.
            feat = feat / feat.norm(dim=-1, keepdim=True)
            # Average the prompt-template ensemble into one embedding per class.
            feat = feat.mean(dim=0)

            # Last layer's CLS->token attention, averaged over heads and templates.
            attn_w = attn[-1].mean(dim=1).mean(dim=0, keepdim=True)[:, 0, 1:]
            feat = self._apply_ace_hgat(self.loss_fn, feat.unsqueeze(0), attn_w, encoder="text")
            feat = F.normalize(feat, dim=-1)
            feats.append(feat)

        return torch.cat(feats, dim=0)

    @torch.no_grad()
    def encode_image(self, pil_image: Image.Image) -> torch.Tensor:
        """Encode one PIL image into refined token features via the image
        ACE-HGAT head.

        Temporarily disables the ViT's global pooling so per-token features
        (CLS + patches) are available, then restores it.
        """
        device = self.logit_scale.device
        old_pool = self.clip_model.visual.trunk.global_pool
        # "" disables pooling so the trunk returns the full token sequence.
        self.clip_model.visual.trunk.global_pool = ""

        # get_attn_scores is provided by patch_timm_vit_return_attn_scores().
        img_features, attn = self.clip_model.visual.trunk.get_attn_scores(self.preprocess(pil_image).unsqueeze(0).to(device))
        img_features = F.normalize(self.clip_model.visual.head(img_features), dim=-1)
        # CLS->patch attention averaged over heads.
        attn_w = attn.mean(dim=1)[:, 0, 1:]
        img_features = self._apply_ace_hgat(self.loss_fn, img_features, attn_w, encoder="img")
        img_features = F.normalize(img_features, dim=-1)
        self.clip_model.visual.trunk.global_pool = old_pool
        return img_features

    def forward(
        self,
        image: Image.Image,
        class_names: list[str],
    ) -> torch.Tensor:
        """Zero-shot classify *image* against *class_names*.

        Returns a 1-D tensor of class probabilities (softmax over the scaled
        CLS-feature similarities).  NOTE(review): logit_scale is used directly
        here (not .exp()), unlike the exp() taken in __init__ — confirm the
        intended temperature convention.
        """
        logit_scale = self.logit_scale
        text_feats = self.encode_texts(class_names)
        image_feats = self.encode_image(image)

        # [:, 0] selects the CLS-node embedding from the refined token graphs.
        logits = (logit_scale * image_feats[:, 0] @ text_feats[:, 0].t())
        return logits.squeeze(0).softmax(dim=-1)
182
+
183
if __name__ == "__main__":
    # Zero-shot pneumonia classification on the RSNA test split.
    model = ACE_LoRA_Model.from_pretrained("aydnarda/ACE-LoRA", force_download=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Fix: keep BOTH metrics on the evaluation device — previously only
    # acc_metric was moved, so auc_metric's state lived on CPU while being
    # fed CUDA tensors.
    auc_metric = BinaryAUROC(thresholds=None).to(device)
    acc_metric = BinaryAccuracy().to(device)
    model = model.to(device)
    model.eval()

    TEST_CSV_PATH = './RSNA/test.csv'
    df = pd.read_csv(TEST_CSV_PATH)
    test_paths = df['Path'].tolist()
    classes = ['No Finding', 'pneumonia']
    logits_list = []
    label_list = []

    for index in tqdm(range(len(df))):
        img_path = test_paths[index]
        img_data = pydicom.dcmread(img_path).pixel_array
        image = Image.fromarray(img_data)

        # One-hot label from the integer Target column
        # (assumed 0 = No Finding, 1 = pneumonia — TODO confirm schema).
        label = torch.zeros(len(classes), dtype=torch.int8, device=device)
        label[df['Target'][index]] = 1
        # (Removed unused `pred` tensor that was allocated every iteration.)
        logits = model(image, classes).unsqueeze(0)
        logits_list.append(logits)
        label_list.append(label.argmax())

    logits_all = torch.cat(logits_list, dim=0)  # (N, C)
    labels_all = torch.stack(label_list)
    # Binary metrics consume the positive-class column ('pneumonia').
    auc = auc_metric(logits_all[:, 1], labels_all)
    acc = acc_metric(logits_all[:, 1], labels_all)

    print("ACC: ", acc)
    print("AUC: ", auc)