qkenr0804 commited on
Commit
e90e75c
·
verified ·
1 Parent(s): 8150b71

Upload 29 files

Browse files
Files changed (30) hide show
  1. .gitattributes +1 -0
  2. GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortestset.py +170 -0
  3. GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortrainset.py +170 -0
  4. GOAL_github/datasets/DCI_segment_only_sim_max_del_org.json +0 -0
  5. GOAL_github/datasets/DCI_test.json +0 -0
  6. GOAL_github/datasets/DCI_test_joint_sim_max_1 +0 -0
  7. GOAL_github/datasets/DCI_train_del_org.json +0 -0
  8. GOAL_github/datasets/docci_segment_sim_bbox_del_org.json +3 -0
  9. GOAL_github/datasets/docci_test.json +0 -0
  10. GOAL_github/datasets/docci_test_joint_sim_max_1 +0 -0
  11. GOAL_github/datasets/docci_train_del_org.json +0 -0
  12. GOAL_github/datasets/urban_dataset_test.json +0 -0
  13. GOAL_github/goal.py +469 -0
  14. GOAL_github/mAP_goal_jointtest.py +256 -0
  15. GOAL_github/retrieval_goal.py +171 -0
  16. GOAL_github/utils/__pycache__/easydict.cpython-39.pyc +0 -0
  17. GOAL_github/utils/__pycache__/func.cpython-310.pyc +0 -0
  18. GOAL_github/utils/__pycache__/func.cpython-311.pyc +0 -0
  19. GOAL_github/utils/__pycache__/func.cpython-39.pyc +0 -0
  20. GOAL_github/utils/__pycache__/randaugment.cpython-310.pyc +0 -0
  21. GOAL_github/utils/__pycache__/randaugment.cpython-311.pyc +0 -0
  22. GOAL_github/utils/__pycache__/randaugment.cpython-39.pyc +0 -0
  23. GOAL_github/utils/__pycache__/transforms.cpython-310.pyc +0 -0
  24. GOAL_github/utils/__pycache__/transforms.cpython-311.pyc +0 -0
  25. GOAL_github/utils/__pycache__/transforms.cpython-39.pyc +0 -0
  26. GOAL_github/utils/func.py +106 -0
  27. GOAL_github/utils/randaugment.py +349 -0
  28. GOAL_github/utils/transforms.py +130 -0
  29. GOAL_github/visualization/visualization_attentionmap_longtestset.py +188 -0
  30. GOAL_github/visualization/visualization_retreival.py +313 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ GOAL_github/datasets/docci_segment_sim_bbox_del_org.json filter=lfs diff=lfs merge=lfs -text
GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortestset.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from io import BytesIO
3
+ from PIL import Image
4
+ import numpy as np
5
+ from transformers import pipeline, CLIPProcessor, CLIPModel
6
+ import cv2
7
+ import torch
8
+ import json
9
+ import os
10
+ from tqdm import tqdm
11
+
12
+ # SAM 모델 로드
13
+ generator = pipeline("mask-generation", device=2, points_per_batch=128)
14
+
15
+ # CLIP 모델 로드
16
+ model_name = "openai/clip-vit-large-patch14-336"
17
+ clip_model = CLIPModel.from_pretrained(model_name)
18
+ clip_processor = CLIPProcessor.from_pretrained(model_name)
19
+
20
+ # 원본 배경 유지
21
+ def keep_original_background(image, mask):
22
+ masked_image = np.array(image)
23
+ mask_3d = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
24
+ masked_image = masked_image * mask_3d + (1 - mask_3d) * np.array(image)
25
+ return Image.fromarray(np.uint8(masked_image))
26
+
27
+ # 이미지 리사이징 함수
28
+ def resize_image_if_needed(image_path, max_size=(2048, 1536)):
29
+ with Image.open(image_path) as img:
30
+ if img.size[0] > 4000 or img.size[1] > 3000:
31
+ img.thumbnail(max_size, Image.LANCZOS)
32
+ return img.copy()
33
+ return img.copy()
34
+
35
# Image-caption matching with CLIP.
def get_clip_similarity(images, texts):
    """Score every (image, text) pair with CLIP.

    Args:
        images: list of PIL images (original + segment crops).
        texts: list of caption sentences.

    Returns:
        (n_images, n_texts) numpy array of CLIP logits, or None when
        preprocessing / the forward pass fails (caller skips the sample).
    """
    try:
        inputs = clip_processor(text=texts, images=images, return_tensors="pt", padding=True)
        # Inference only: no_grad avoids building the autograd graph
        # (large memory saving for many segments per image).
        with torch.no_grad():
            outputs = clip_model(**inputs)
        return outputs.logits_per_image.cpu().numpy()
    except Exception as e:
        print(f"Error in CLIP processing: {str(e)}")
        return None
44
+
45
# Simple sentence tokenization.
def tokenize_caption(text):
    """Split *text* into sentences on '.', dropping empty fragments."""
    sentences = []
    for fragment in text.split('.'):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
48
+
49
# Base paths.
matching_results_path = "matching_results"
output_base_dir = "segment_with_background_DCI_test_set_max_0.01"
os.makedirs(output_base_dir, exist_ok=True)

# Collect the folders that start with 'test_sa_'.
train_folders = [f for f in os.listdir(matching_results_path) if f.startswith('test_sa_')]

# Process each folder.
for train_folder in tqdm(train_folders, desc="Processing folders"):
    folder_path = os.path.join(matching_results_path, train_folder)

    # Locate the annotation JSON (sa_xxxxxx_result.json).
    json_files = [f for f in os.listdir(folder_path) if f.endswith('_result.json')]
    if not json_files:
        continue

    json_path = os.path.join(folder_path, json_files[0])

    # Load the annotation.
    try:
        with open(json_path, 'r') as f:
            annotation = json.load(f)
    except Exception as e:
        print(f"Error loading JSON for {train_folder}: {str(e)}")
        continue

    image_filename = annotation['original_image']
    image_path = os.path.join(folder_path, image_filename)
    caption = annotation['extra_caption']

    # Per-image output directory.
    output_dir = os.path.join(output_base_dir, f"{image_filename}_results")
    os.makedirs(output_dir, exist_ok=True)

    # Skip images that were already processed.
    if os.path.exists(os.path.join(output_dir, 'matched_dataset_test.json')):
        continue

    # Load the image and run segmentation.
    try:
        original_image = Image.open(image_path)
        resized_image = resize_image_if_needed(image_path)

        # Run SAM on the (possibly) resized image.
        outputs = generator(resized_image, points_per_batch=128)

        # Upsample the masks back to the original resolution if needed.
        if original_image.size != resized_image.size:
            for mask_idx in range(len(outputs['masks'])):
                resized_mask = Image.fromarray(outputs['masks'][mask_idx])
                resized_mask = resized_mask.resize(original_image.size, Image.NEAREST)
                outputs['masks'][mask_idx] = np.array(resized_mask)

        image = original_image
    except Exception as e:
        print(f"Error processing {image_filename}: {str(e)}")
        continue

    # Filter segments by relative area and build cropped candidates.
    min_area_ratio = 0.01
    max_area_ratio = 0.8
    total_area = image.size[0] * image.size[1]
    # Index 0 is always the full original image.
    segmented_images = [image]
    segmented_image_paths = ["original_image.jpg"]

    for mask_idx, mask in enumerate(outputs['masks']):
        area_ratio = np.sum(mask) / total_area
        if min_area_ratio <= area_ratio <= max_area_ratio:
            masked_image = keep_original_background(image, mask)

            # Crop to the segment's bounding box.
            ys, xs = np.where(mask)
            y_min, y_max, x_min, x_max = ys.min(), ys.max(), xs.min(), xs.max()
            crop = np.array(masked_image)[y_min:y_max + 1, x_min:x_max + 1]

            segmented_images.append(Image.fromarray(np.uint8(crop)))
            segmented_image_paths.append(f"{image_filename.split('.')[0]}_max_{mask_idx+1}.png")

    # Split the long caption into sentences.
    tokenized_captions = tokenize_caption(caption)

    # Similarity of every sentence-image pair.
    similarity_matrix = get_clip_similarity(segmented_images, tokenized_captions)
    if similarity_matrix is None:
        print(f"Skipping image {image_filename} due to CLIP processing error")
        continue

    matched_data = []
    saved_images = set()

    # Match each sentence with its highest-similarity image.
    for col, sentence in enumerate(tokenized_captions):
        best_idx = np.argmax(similarity_matrix[:, col])
        matched_data.append({
            "caption": sentence,
            "matched_image_path": segmented_image_paths[best_idx],
            "similarity_score": float(similarity_matrix[best_idx, col])
        })
        saved_images.add(best_idx)

    # Save the original image.
    image.save(os.path.join(output_dir, "original_image.jpg"))

    # Save only the segment crops that won at least one sentence.
    for best_idx in saved_images:
        if best_idx != 0:  # index 0 is the original image, already saved
            segmented_images[best_idx].save(os.path.join(output_dir, segmented_image_paths[best_idx]))

    # Persist the matching result.
    with open(os.path.join(output_dir, 'matched_dataset_test.json'), 'w') as f:
        json.dump(matched_data, f, indent=2)

print("모든 test set 이미지 처리가 완료되었습니다.")
GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortrainset.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from io import BytesIO
3
+ from PIL import Image
4
+ import numpy as np
5
+ from transformers import pipeline, CLIPProcessor, CLIPModel
6
+ import cv2
7
+ import torch
8
+ import json
9
+ import os
10
+ from tqdm import tqdm
11
+
12
+ # SAM 모델 로드
13
+ generator = pipeline("mask-generation", device=1, points_per_batch=128)
14
+
15
+ # CLIP 모델 로드
16
+ model_name = "openai/clip-vit-large-patch14-336"
17
+ clip_model = CLIPModel.from_pretrained(model_name)
18
+ clip_processor = CLIPProcessor.from_pretrained(model_name)
19
+
20
+ # 원본 배경 유지
21
+ def keep_original_background(image, mask):
22
+ masked_image = np.array(image)
23
+ mask_3d = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
24
+ masked_image = masked_image * mask_3d + (1 - mask_3d) * np.array(image)
25
+ return Image.fromarray(np.uint8(masked_image))
26
+
27
+ # 이미지 리사이징 함수
28
+ def resize_image_if_needed(image_path, max_size=(2048, 1536)):
29
+ with Image.open(image_path) as img:
30
+ if img.size[0] > 4000 or img.size[1] > 3000:
31
+ img.thumbnail(max_size, Image.LANCZOS)
32
+ return img.copy()
33
+ return img.copy()
34
+
35
# Image-caption matching with CLIP.
def get_clip_similarity(images, texts):
    """Score every (image, text) pair with CLIP.

    Args:
        images: list of PIL images (original + segment crops).
        texts: list of caption sentences.

    Returns:
        (n_images, n_texts) numpy array of CLIP logits, or None when
        preprocessing / the forward pass fails (caller skips the sample).
    """
    try:
        inputs = clip_processor(text=texts, images=images, return_tensors="pt", padding=True)
        # Inference only: no_grad avoids building the autograd graph
        # (large memory saving for many segments per image).
        with torch.no_grad():
            outputs = clip_model(**inputs)
        return outputs.logits_per_image.cpu().numpy()
    except Exception as e:
        print(f"Error in CLIP processing: {str(e)}")
        return None
44
+
45
# Simple sentence tokenization.
def tokenize_caption(text):
    """Split *text* into sentences on '.', dropping empty fragments."""
    sentences = []
    for fragment in text.split('.'):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
48
+
49
# Base paths.
matching_results_path = "matching_results"
output_base_dir = "segment_with_background_DCI_train_set_max_0.01"
os.makedirs(output_base_dir, exist_ok=True)

# Collect the folders that start with 'train_sa_'.
train_folders = [f for f in os.listdir(matching_results_path) if f.startswith('train_sa_')]

# Process each folder.
for train_folder in tqdm(train_folders, desc="Processing folders"):
    folder_path = os.path.join(matching_results_path, train_folder)

    # Locate the annotation JSON (sa_xxxxxx_result.json).
    json_files = [f for f in os.listdir(folder_path) if f.endswith('_result.json')]
    if not json_files:
        continue

    json_path = os.path.join(folder_path, json_files[0])

    # Load the annotation.
    try:
        with open(json_path, 'r') as f:
            annotation = json.load(f)
    except Exception as e:
        print(f"Error loading JSON for {train_folder}: {str(e)}")
        continue

    image_filename = annotation['original_image']
    image_path = os.path.join(folder_path, image_filename)
    caption = annotation['extra_caption']

    # Per-image output directory.
    output_dir = os.path.join(output_base_dir, f"{image_filename}_results")
    os.makedirs(output_dir, exist_ok=True)

    # Skip images that were already processed.
    if os.path.exists(os.path.join(output_dir, 'matched_dataset_train.json')):
        continue

    # Load the image and run segmentation.
    try:
        original_image = Image.open(image_path)
        resized_image = resize_image_if_needed(image_path)

        # Run SAM on the (possibly) resized image.
        outputs = generator(resized_image, points_per_batch=128)

        # Upsample the masks back to the original resolution if needed.
        if original_image.size != resized_image.size:
            for mask_idx in range(len(outputs['masks'])):
                resized_mask = Image.fromarray(outputs['masks'][mask_idx])
                resized_mask = resized_mask.resize(original_image.size, Image.NEAREST)
                outputs['masks'][mask_idx] = np.array(resized_mask)

        image = original_image
    except Exception as e:
        print(f"Error processing {image_filename}: {str(e)}")
        continue

    # Filter segments by relative area and build cropped candidates.
    min_area_ratio = 0.01
    max_area_ratio = 0.8
    total_area = image.size[0] * image.size[1]
    # Index 0 is always the full original image.
    segmented_images = [image]
    segmented_image_paths = ["original_image.jpg"]

    for mask_idx, mask in enumerate(outputs['masks']):
        area_ratio = np.sum(mask) / total_area
        if min_area_ratio <= area_ratio <= max_area_ratio:
            masked_image = keep_original_background(image, mask)

            # Crop to the segment's bounding box.
            ys, xs = np.where(mask)
            y_min, y_max, x_min, x_max = ys.min(), ys.max(), xs.min(), xs.max()
            crop = np.array(masked_image)[y_min:y_max + 1, x_min:x_max + 1]

            segmented_images.append(Image.fromarray(np.uint8(crop)))
            segmented_image_paths.append(f"{image_filename.split('.')[0]}_max_{mask_idx+1}.png")

    # Split the long caption into sentences.
    tokenized_captions = tokenize_caption(caption)

    # Similarity of every sentence-image pair.
    similarity_matrix = get_clip_similarity(segmented_images, tokenized_captions)
    if similarity_matrix is None:
        print(f"Skipping image {image_filename} due to CLIP processing error")
        continue

    matched_data = []
    saved_images = set()

    # Match each sentence with its highest-similarity image.
    for col, sentence in enumerate(tokenized_captions):
        best_idx = np.argmax(similarity_matrix[:, col])
        matched_data.append({
            "caption": sentence,
            "matched_image_path": segmented_image_paths[best_idx],
            "similarity_score": float(similarity_matrix[best_idx, col])
        })
        saved_images.add(best_idx)

    # Save the original image.
    image.save(os.path.join(output_dir, "original_image.jpg"))

    # Save only the segment crops that won at least one sentence.
    for best_idx in saved_images:
        if best_idx != 0:  # index 0 is the original image, already saved
            segmented_images[best_idx].save(os.path.join(output_dir, segmented_image_paths[best_idx]))

    # Persist the matching result.
    with open(os.path.join(output_dir, 'matched_dataset_train.json'), 'w') as f:
        json.dump(matched_data, f, indent=2)

print("모든 train set 이미지 처리가 완료되었습니다.")
GOAL_github/datasets/DCI_segment_only_sim_max_del_org.json ADDED
The diff for this file is too large to render. See raw diff
 
GOAL_github/datasets/DCI_test.json ADDED
The diff for this file is too large to render. See raw diff
 
GOAL_github/datasets/DCI_test_joint_sim_max_1 ADDED
File without changes
GOAL_github/datasets/DCI_train_del_org.json ADDED
The diff for this file is too large to render. See raw diff
 
GOAL_github/datasets/docci_segment_sim_bbox_del_org.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ce1febc6eca26cb77936fdfb2f4fc6aa8c30eb52172505e22a015d06757a7bc
3
+ size 35738701
GOAL_github/datasets/docci_test.json ADDED
The diff for this file is too large to render. See raw diff
 
GOAL_github/datasets/docci_test_joint_sim_max_1 ADDED
File without changes
GOAL_github/datasets/docci_train_del_org.json ADDED
The diff for this file is too large to render. See raw diff
 
GOAL_github/datasets/urban_dataset_test.json ADDED
The diff for this file is too large to render. See raw diff
 
GOAL_github/goal.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
4
+ import json
5
+ import argparse
6
+ from PIL import Image
7
+ from pathlib import Path
8
+ from torch.utils.data import Dataset
9
+ import lightning as L
10
+ import transformers
11
+ import torch.nn.functional as F
12
+ import shutil
13
+ import time
14
+ import numpy as np
15
+ from utils.func import *
16
+ from utils.transforms import *
17
+ from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
18
+ import shutil
19
+ import math
20
+ import random
21
+ import wandb
22
+
23
+ torch.autograd.set_detect_anomaly(True) # Enable anomaly detection
24
+
25
def clip_loss(sim):
    """Symmetric CLIP (InfoNCE) loss on a square similarity matrix.

    The diagonal holds matching image/text pairs; cross-entropy is
    averaged over the image->text and text->image directions. Uses
    F.cross_entropy directly instead of instantiating
    nn.CrossEntropyLoss twice per call.
    """
    gt = torch.arange(len(sim), dtype=torch.long, device=sim.device)
    return (F.cross_entropy(sim, gt) + F.cross_entropy(sim.t(), gt)) / 2.0
28
+
29
def get_patch_tokens_from_bbox(patch_tokens, bbox, b, original_image_size, image_size=224, patch_size=16):
    """Mean-pool the ViT patch tokens covered by the b-th bounding box.

    Args:
        patch_tokens: (1, 1 + n_patches, D) hidden states; position 0 is CLS.
        bbox: dict with 'x1'/'y1'/'x2'/'y2' batched tensors in original-image pixels.
        b: batch index into the bbox tensors.
        original_image_size: (width, height) of the source image.
        image_size: model input resolution the bbox is rescaled to.
        patch_size: ViT patch edge length.

    Returns:
        (1, D) mean of the patch tokens the (rescaled) box touches.
    """
    org_w, org_h = original_image_size

    def _scaled(key, extent):
        # Map an original-pixel coordinate onto the model input grid.
        return int(round(bbox[key][b].item() * image_size / extent))

    left, top = _scaled('x1', org_w), _scaled('y1', org_h)
    right, bottom = _scaled('x2', org_w), _scaled('y2', org_h)

    # Clamp to the model input resolution.
    left = min(max(left, 0), image_size - 1)
    top = min(max(top, 0), image_size - 1)
    right = min(max(right, 0), image_size)
    bottom = min(max(bottom, 0), image_size)

    # Include any patch the box touches (ceil on the far edge).
    col_first, row_first = left // patch_size, top // patch_size
    col_last = (right + patch_size - 1) // patch_size
    row_last = (bottom + patch_size - 1) // patch_size

    grid = image_size // patch_size
    # +1 skips the CLS token at sequence position 0.
    indices = [r * grid + c + 1
               for r in range(row_first, row_last)
               for c in range(col_first, col_last)]

    return torch.mean(patch_tokens[:, indices, :], dim=1)
63
+
64
def get_text_tokens_from_segment(text_tokens, org_text, seg_text, processor):
    """Mean-pool the text tokens of *org_text* that correspond to *seg_text*.

    Args:
        text_tokens: (B, L, D) hidden states for the full original caption.
        org_text: original caption string.
        seg_text: the sentence/segment to locate inside org_text.
        processor: CLIP processor (used only for tokenization lengths).

    Returns:
        (B, D) mean of the token span covering the segment.
    """
    # Normalize whitespace on both strings.
    org_text = ' '.join(org_text.split()).strip()
    seg_text = ' '.join(seg_text.split()).strip()

    # Sentence-split the original caption.
    sentences = [s.strip() for s in org_text.split('.') if s.strip()]

    # Locate the segment: prefer an exact sentence match, otherwise fall
    # back to the raw substring position.
    seg_pos = org_text.find(seg_text)
    sent_idx = -1
    running = 0
    for idx, sentence in enumerate(sentences):
        if sentence.strip() == seg_text:
            seg_pos = running
            sent_idx = idx
            break
        running += len(sentence.strip()) + 2  # account for '. ' separator

    assert seg_pos != -1, f"Segment text not found in original text"

    def _n_tokens(text):
        # Token count of *text* without padding/truncation.
        return len(processor(text=text,
                             return_tensors="pt",
                             padding=False,
                             truncation=False).input_ids[0])

    # Segment length in tokens, excluding CLS and EOS.
    seg_token_length = _n_tokens(seg_text) - 2

    # Token offset of the segment = token count of everything before it.
    if sent_idx != -1:
        text_before = '. '.join(sentences[:sent_idx]) + ('. ' if sent_idx > 0 else '')
    else:
        text_before = org_text[:seg_pos]
    start_idx = _n_tokens(text_before)

    # Clip the span to the model's maximum sequence length.
    max_length = text_tokens.shape[1]
    if start_idx >= max_length:
        # Segment starts beyond the truncated caption: take a span of the
        # segment's length from the end (after the CLS token at index 0).
        end_idx = max_length - 1
        start_idx = max(1, end_idx - seg_token_length)
    else:
        end_idx = min(start_idx + seg_token_length, max_length - 1)

    relevant_tokens = text_tokens[:, start_idx:end_idx, :]

    # Fallback when the span collapsed to nothing: use leading tokens.
    if relevant_tokens.shape[1] == 0:
        relevant_tokens = text_tokens[:, 1:min(1 + seg_token_length, max_length), :]

    return torch.mean(relevant_tokens, dim=1)
145
+
146
class DLoader(Dataset):
    """Dataset pairing each original image/caption with its single
    best-matching (highest CLIP similarity) segment crop."""

    def __init__(self, data_list, processor, new_max_token):
        self.data_list = data_list
        self.processor = processor
        self.new_max_token = new_max_token

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, name):
        """Load an RGB image and return it with its original (w, h) size."""
        img = Image.open(name).convert("RGB")
        return img, img.size

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        item = self.data_list[idx]
        org_image, org_image_size = self._load_image(item["original_filename"])
        org_caption = item["original_caption"]

        # Always take the segment with the highest similarity score.
        segment = max(item["segment"], key=lambda x: x["similarity_score"])
        seg_image = self._load_image(segment["filename"])[0]
        seg_caption = segment["caption"]
        bbox = segment["bbox_coordinates"]

        def _encode(image, text):
            # Pad/truncate captions to the extended token budget.
            return self.processor(images=image, text=text, return_tensors="pt",
                                  truncation=True, padding="max_length",
                                  max_length=self.new_max_token)

        org_data = _encode(org_image, org_caption)
        seg_data = _encode(seg_image, seg_caption)

        return (org_data.pixel_values[0], org_data.input_ids[0],
                seg_data.pixel_values[0], seg_data.input_ids[0],
                bbox, org_caption, seg_caption, org_image_size,
                item["original_filename"], segment["filename"])
183
+
184
def main(args):
    """Set up Fabric/DDP, data, and the CLIP model (with extended
    positional embeddings), then launch training."""
    wandb.init(project="CLIP_Training_real", config=args)

    fabric = L.Fabric(accelerator="cuda",
                      devices=args.world_size,
                      strategy="ddp",
                      precision="bf16")
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)

    with open(args.dataset) as f:
        train_list = json.load(f)

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.CLIPModel.from_pretrained(args.model)
        # Stretch positional embeddings so captions up to new_max_token fit.
        longclip_pos_embeddings(model, args.new_max_token)

    # Optionally resume from a checkpoint.
    if args.ckpt:
        if fabric.global_rank == 0:
            print(f"Loading checkpoint from {args.ckpt}")
        model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        if fabric.global_rank == 0:
            print("Checkpoint loaded successfully")

    print_trainable_parameters(fabric, model)

    dataset_train = DLoader(train_list, processor, args.new_max_token)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        shuffle=True,
    )
    train_loader = fabric.setup_dataloaders(train_loader)

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=args.init_lr,
                                  weight_decay=args.weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    train(fabric, model, optimizer, train_loader, processor)
235
+
236
def train(fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader, processor) -> None:
    """Fine-tune CLIP with four losses: global image-text contrastive on
    originals and segments, plus patch-level and text-span-level alignment
    between original and segment embeddings.

    Reads hyperparameters from the module-level `args`. Fixes applied:
    the loop counter no longer shadows the builtin `iter`, and an unused
    debug re-tokenization block (tokens were decoded and indices computed
    but never consumed) has been removed from the text-alignment loop.
    """
    step = 0  # global iteration counter (was `iter`, shadowing the builtin)
    total_iter = len(train_loader) * args.epochs

    mse_loss = torch.nn.MSELoss()

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        epoch_loss_org = 0.0
        epoch_loss_seg = 0.0
        epoch_loss_patch = 0.0
        epoch_loss_text = 0.0

        for i, samples in enumerate(train_loader):
            # Cosine learning-rate schedule.
            lr = (args.init_lr - args.min_lr) * 0.5 * (1.0 + math.cos(math.pi * step / total_iter)) + args.min_lr
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            org_image, org_text, seg_image, seg_text, bbox, org_caption, seg_caption, org_image_sizes, org_image_paths, seg_image_paths = samples

            # Forward originals and segments in one concatenated batch.
            outputs = model(pixel_values=torch.cat((org_image, seg_image), dim=0),
                            input_ids=torch.cat((org_text, seg_text), dim=0),
                            output_hidden_states=True)
            # Separate tower passes expose per-token hidden states.
            vision_outputs = model.vision_model(torch.cat((org_image, seg_image), dim=0), output_hidden_states=True)
            text_outputs = model.text_model(torch.cat((org_text, seg_text), dim=0), output_hidden_states=True)

            # Split concatenated embeddings back into org / seg halves.
            batch_size = org_image.shape[0]
            org_image_embeds, seg_image_embeds = outputs.image_embeds[:batch_size], outputs.image_embeds[batch_size:]
            org_text_embeds, seg_text_embeds = outputs.text_embeds[:batch_size], outputs.text_embeds[batch_size:]

            # Last-layer token sequences for the originals only.
            org_patch_tokens = vision_outputs.hidden_states[-1][:batch_size]  # (B, N, D)
            org_text_tokens = text_outputs.hidden_states[-1][:batch_size]     # (B, L, D)

            # Standard CLIP contrastive losses (all-gather across ranks).
            eps = 1e-8
            x_i = batch_align(fabric, F.normalize(outputs.image_embeds + eps))
            x_t = batch_align(fabric, F.normalize(outputs.text_embeds + eps))
            x_i_org, x_i_seg = x_i.chunk(2)
            x_t_org, x_t_seg = x_t.chunk(2)

            sim_org = model.logit_scale.exp() * x_i_org @ x_t_org.t()
            loss_org = clip_loss(sim_org)
            sim_seg = model.logit_scale.exp() * x_i_seg @ x_t_seg.t()
            loss_seg = clip_loss(sim_seg)

            # Patch-level alignment: pool original patch tokens under each
            # segment's bbox and pull them toward the segment image embedding.
            patch_pooled = []
            for b in range(batch_size):
                # org_image_sizes arrives as [width_tensor, height_tensor]
                # (collated from the dataset's (width, height) tuples).
                img_width = org_image_sizes[0][b].item()
                img_height = org_image_sizes[1][b].item()
                img_size = (img_width, img_height)

                pooled = get_patch_tokens_from_bbox(org_patch_tokens[b:b+1],
                                                    bbox,
                                                    b,
                                                    img_size,
                                                    image_size=args.image_size,
                                                    patch_size=16)
                patch_pooled.append(pooled)

            patch_pooled = torch.cat(patch_pooled, dim=0)
            patch_pooled = model.vision_model.post_layernorm(patch_pooled)
            patch_pooled = model.visual_projection(patch_pooled)
            patch_pooled = F.normalize(patch_pooled + eps, dim=-1)
            seg_image_embeds = F.normalize(seg_image_embeds + eps, dim=-1)

            # Cosine similarity directly (no logit_scale); push diagonal to 1.
            sim_patch = patch_pooled @ seg_image_embeds.t()
            patch_diag = torch.diag(sim_patch)
            loss_patch = mse_loss(patch_diag, torch.ones_like(patch_diag))

            # Text-level alignment: pool the original-caption tokens covering
            # the segment sentence and pull them toward the segment text embedding.
            text_pooled = []
            for b in range(batch_size):
                pooled = get_text_tokens_from_segment(org_text_tokens[b:b+1],
                                                      org_caption[b],
                                                      seg_caption[b],
                                                      processor)
                text_pooled.append(pooled)
            text_pooled = torch.cat(text_pooled, dim=0)

            text_pooled = model.text_model.final_layer_norm(text_pooled)
            text_pooled = model.text_projection(text_pooled)
            text_pooled = F.normalize(text_pooled + eps, dim=-1)
            seg_text_embeds = F.normalize(seg_text_embeds + eps, dim=-1)

            sim_text = text_pooled @ seg_text_embeds.t()
            text_diag = torch.diag(sim_text)
            loss_text = mse_loss(text_diag, torch.ones_like(text_diag))

            # Total loss (segment contrastive term down-weighted).
            loss = loss_org + 0.5 * loss_seg + loss_patch + loss_text

            epoch_loss += loss.item()
            epoch_loss_org += loss_org.item()
            epoch_loss_seg += loss_seg.item()
            epoch_loss_patch += loss_patch.item()
            epoch_loss_text += loss_text.item()

            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            if fabric.global_rank == 0:
                wandb.log({
                    "iter": step,
                    "lr": lr,
                    "loss": loss.item(),
                    "loss_org": loss_org.item(),
                    "loss_seg": loss_seg.item(),
                    "loss_patch": loss_patch.item(),
                    "loss_text": loss_text.item(),
                    "epoch": epoch,
                    "progress": (step / total_iter) * 100,
                    "batch_size": args.batch_size,
                    "logit_scale": model.logit_scale.exp().item(),
                    "patch_similarity": patch_diag.mean().item(),  # average patch similarity
                    "text_similarity": text_diag.mean().item(),    # average text similarity
                })

            fabric.print(f"epoch {epoch} iter {step} ({(step/total_iter)*100:.4f}%) lr {lr:.6f} "
                         f"loss {loss.item():.4f} (org: {loss_org.item():.4f}, seg: {loss_seg.item():.4f}, "
                         f"patch: {loss_patch.item():.4f}, text: {loss_text.item():.4f} "
                         f"patch_sim: {patch_diag.mean().item():.4f}, text_sim: {text_diag.mean().item():.4f})")
            step += 1

        # Epoch-level averages.
        avg_epoch_loss = epoch_loss / len(train_loader)
        avg_epoch_loss_org = epoch_loss_org / len(train_loader)
        avg_epoch_loss_seg = epoch_loss_seg / len(train_loader)
        avg_epoch_loss_patch = epoch_loss_patch / len(train_loader)
        avg_epoch_loss_text = epoch_loss_text / len(train_loader)

        if fabric.global_rank == 0:
            wandb.log({
                "epoch": epoch,
                "avg_epoch_loss": avg_epoch_loss,
                "avg_epoch_loss_org": avg_epoch_loss_org,
                "avg_epoch_loss_seg": avg_epoch_loss_seg,
                "avg_epoch_loss_patch": avg_epoch_loss_patch,
                "avg_epoch_loss_text": avg_epoch_loss_text,
            })

        # Save model weights once per epoch (rank 0 only).
        save_path = os.path.join(args.output_dir,
                                 f"GOAL_12_{os.path.splitext(os.path.basename(args.model))[0]}_"
                                 f"{os.path.splitext(os.path.basename(args.dataset))[0]}_{epoch+1}_{args.image_size}.pth")

        fabric.barrier()
        if fabric.global_rank == 0:
            model_state_dict = model.state_dict()
            cpu_state_dict = {k: v.cpu() for k, v in model_state_dict.items()}
            torch.save(cpu_state_dict, save_path)
            fabric.print(f"Model saved to {save_path}")
        fabric.barrier()
431
+
432
+
433
def get_args_parser():
    """Build the CLI argument parser for CLIP fine-tuning.

    Fix: the duplicated `parser.set_defaults(pin_mem=True)` call has been
    removed (it was listed twice with identical effect).
    """
    parser = argparse.ArgumentParser('CLIP Training', add_help=False)
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU')
    parser.add_argument('--epochs', default=10, type=int)
    parser.add_argument('--image_size', default=224, type=int)
    parser.add_argument('--new_max_token', default=248, type=int)
    parser.add_argument('--dataset', default='datasets/docci_segment_sim_bbox_del_org.json', type=str)
    parser.add_argument('--model', default='openai/clip-vit-base-patch16', type=str)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--init_lr', type=float, default=5e-6, metavar='LR')  # originally 5e-6
    parser.add_argument('--min_lr', type=float, default=0, metavar='LR')
    parser.add_argument('--output_dir', default='finetune_out_SA_1B_100k_plus_docci/goal_bbox_local_token_align_batch16_only_max_pair_base16_patch16_real',
                        help='path where to save, empty for no saving')
    parser.add_argument('--save_interval', default=1, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.add_argument('--wandb_project', type=str, default='CLIP_Training', help='wandb project name')
    parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint file')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')

    return parser
463
+
464
+ if __name__ == "__main__":
465
+ args = get_args_parser()
466
+ args = args.parse_args()
467
+ if args.output_dir:
468
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
469
+ main(args)
GOAL_github/mAP_goal_jointtest.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import os
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
5
+ import sys
6
+ from pathlib import Path
7
+ import lightning as L
8
+ import numpy as np
9
+ import torch
10
+ import transformers
11
+ from torch.utils.data import DataLoader, Dataset, Subset
12
+ from PIL import Image
13
+ import torch.nn.functional as F
14
+ from utils.func import *
15
+ from utils.transforms import *
16
+
17
# Hyperparameters
micro_batch_size = 32  # per-step batch size for the evaluation DataLoader
devices = 1  # number of GPUs handed to Lightning Fabric
num_workers = 1  # DataLoader worker processes
21
+
22
class QueryLoader(Dataset):
    """Dataset over (image, caption) records for joint org/segment retrieval.

    Each record in ``data_list`` is a dict with ``filename`` and ``caption``.
    Segment crops are recognized by the substring ``segment_with_background``
    in their path and grouped under their originating image.
    """

    def __init__(self, data_list, processor):
        self.data_list = data_list
        self.processor = processor  # CLIP processor used to tensorize image+text
        self.image_to_segs, self.org_images, self.filename_to_caption = self._create_mappings()

    def _create_mappings(self):
        """Build the org-image -> segment-records, org-image set, and
        filename -> caption maps in one pass over ``data_list``."""
        image_to_segs = {}
        org_images = set()
        filename_to_caption = {}
        for item in self.data_list:
            filename = item['filename']
            filename_to_caption[filename] = item['caption']
            if 'segment_with_background' in filename:
                # Segment crop: attach it to its original image's bucket.
                # Relies on the module-level get_org_filename() path convention.
                org_filename = get_org_filename(filename)
                if org_filename not in image_to_segs:
                    image_to_segs[org_filename] = []
                image_to_segs[org_filename].append(item)
            else:
                # Original (non-segment) image; start an empty segment bucket.
                org_images.add(filename)
                image_to_segs[filename] = []
        return image_to_segs, org_images, filename_to_caption

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, id: int):
        return Image.open(self.data_list[id]["filename"]).convert("RGB")

    def _load_target(self, id: int):
        return self.data_list[id]["caption"]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self._load_image(idx)
        caption = self._load_target(idx)
        # NOTE(review): reads the module-level ``args`` for max_length — only
        # valid after the script entry point ran; confirm before reusing elsewhere.
        data = self.processor(images=image, text=caption, return_tensors="pt", truncation=True, padding="max_length", max_length=args.new_max_token)
        return data.pixel_values[0], data.input_ids[0], self.data_list[idx]["filename"]
61
+
62
def process_chunk(fabric, model, data_loader):
    """Encode every batch in ``data_loader``; return stacked L2-normalized
    image embeddings, text embeddings, and the flat list of filenames."""
    image_feats = []
    text_feats = []
    names = []
    for pixel_values, input_ids, batch_names in data_loader:
        with torch.no_grad():
            out = model(pixel_values=pixel_values.to(fabric.device),
                        input_ids=input_ids.to(fabric.device))
        image_feats.append(F.normalize(out.image_embeds))
        text_feats.append(F.normalize(out.text_embeds))
        names.extend(batch_names)
    return torch.cat(image_feats), torch.cat(text_feats), names
76
+
77
def compute_similarity(images, texts):
    """Pairwise dot-product matrix; cosine similarity when inputs are
    already L2-normalized."""
    return images @ texts.t()
79
+
80
def compute_ap(ranks, relevant_items, k=None):
    """Average precision of a ranked list against a set of relevant items.

    A falsy ``k`` means "use the whole list"; the normalizer is always the
    full number of relevant items.
    """
    if not relevant_items:
        return 0.0
    considered = ranks[:k] if k else ranks
    hits = 0.0
    precision_sum = 0.0
    for position, item in enumerate(considered):
        if item in relevant_items:
            hits += 1.0
            precision_sum += hits / (position + 1.0)
    return precision_sum / len(relevant_items)
90
+
91
def get_org_filename(filename):
    """Map a segment-crop path back to its original image filename.

    Segment paths look like ``.../<org>_results/segment_with_background_*.jpg``;
    anything else is returned unchanged.
    """
    if 'segment_with_background' in filename:
        parts = filename.split('/')
        org_filename = parts[-2].split('_results')[0]
        if not org_filename.endswith('.jpg'):
            org_filename += '.jpg'
        return org_filename
    return filename

def _first_index_map(all_filenames):
    """First-occurrence filename -> index map (matches ``list.index`` choice)."""
    pos = {}
    for i, fn in enumerate(all_filenames):
        pos.setdefault(fn, i)
    return pos

def get_relevant_items(filename, all_filenames, image_to_segs, org_images, filename_to_caption):
    """Indices in ``all_filenames`` relevant to an image query.

    Org-image query: all of its segment crops. Segment query: its org image
    (if present) plus itself. Perf fix: uses one precomputed index map instead
    of a ``list.index`` call per segment (was O(n) each; note a segment
    filename missing from ``all_filenames`` now raises KeyError, not ValueError).
    """
    pos = _first_index_map(all_filenames)
    org_filename = get_org_filename(filename)

    if filename in org_images:
        # Query is an org image
        relevant_items = [pos[seg['filename']] for seg in image_to_segs[filename]]
    else:
        # Query is a seg image
        relevant_items = []
        if org_filename in pos:
            relevant_items.append(pos[org_filename])
        if filename in pos:
            relevant_items.append(pos[filename])
    return relevant_items

def get_relevant_captions(filename, image_to_segs, org_images, filename_to_caption):
    """Captions relevant to an image query: its own caption plus, for an org
    image, every segment caption; for a segment, the org image's caption."""
    org_filename = get_org_filename(filename)

    if filename in org_images:
        # Query is an org image: start with its own caption.
        relevant_captions = [filename_to_caption[filename]]
        relevant_captions.extend(seg['caption'] for seg in image_to_segs[filename])
    else:
        # Query is a seg image.
        relevant_captions = [filename_to_caption[filename]]
        if org_filename in filename_to_caption:
            relevant_captions.append(filename_to_caption[org_filename])
    return relevant_captions

def get_relevant_items_for_text(query_caption, all_filenames, image_to_segs, org_images, filename_to_caption):
    """Image indices relevant to a caption query: every image carrying the
    caption, plus its segments (org query) or its org image (segment query).
    Perf fix: one index map instead of repeated ``list.index`` scans."""
    pos = _first_index_map(all_filenames)
    relevant_items = []
    for filename, caption in filename_to_caption.items():
        if caption != query_caption:
            continue
        if filename in org_images:
            # Caption belongs to an org image: the image and all its segments.
            relevant_items.append(pos[filename])
            relevant_items.extend(pos[seg['filename']] for seg in image_to_segs[filename])
        else:
            # Caption belongs to a segment: the segment and, if indexed, its org.
            relevant_items.append(pos[filename])
            org_filename = get_org_filename(filename)
            if org_filename in pos:
                relevant_items.append(pos[org_filename])
    return list(set(relevant_items))  # de-duplicate
146
+
147
@torch.no_grad()
def test(fabric: L.Fabric, model: torch.nn.Module, query_loader, k=None) -> torch.Tensor:
    """Encode the whole joint test set in chunks, then report mAP@k for both
    image-to-text and text-to-image retrieval over org images, segment crops,
    and their captions."""
    fabric.print("Testing ...")

    chunk_size = 5000  # encode at most this many samples per pass to bound GPU memory
    dataset_size = len(query_loader.dataset)
    all_images = []
    all_texts = []
    all_filenames = []

    for start_idx in range(0, dataset_size, chunk_size):
        end_idx = min(start_idx + chunk_size, dataset_size)
        chunk_dataset = Subset(query_loader.dataset, range(start_idx, end_idx))
        # NOTE(review): batch_size/num_workers are read off the fabric-wrapped
        # loader; confirm these attributes proxy through to the inner DataLoader.
        chunk_loader = DataLoader(chunk_dataset, batch_size=query_loader.batch_size, shuffle=False, num_workers=query_loader.num_workers)

        chunk_images, chunk_texts, chunk_filenames = process_chunk(fabric, model, chunk_loader)

        all_images.append(chunk_images)
        all_texts.append(chunk_texts)
        all_filenames.extend(chunk_filenames)

        torch.cuda.empty_cache()  # release per-chunk activations before the next pass

    all_images = torch.cat(all_images)
    all_texts = torch.cat(all_texts)

    # Full (num_images x num_texts) similarity matrix over the test set.
    similarity = compute_similarity(all_images, all_texts)

    image_to_segs = query_loader.dataset.image_to_segs
    org_images = query_loader.dataset.org_images
    filename_to_caption = query_loader.dataset.filename_to_caption
    mAP_i2t = 0.0
    mAP_t2i = 0.0

    # Image to Text: each similarity row is one image query over all captions.
    for i, filename in enumerate(all_filenames):
        relevant_captions = get_relevant_captions(filename, image_to_segs, org_images, filename_to_caption)

        i2t_ranks = torch.argsort(similarity[i], descending=True).tolist()
        # A text column is relevant if its image's caption is in the relevant set.
        i2t_relevant = [idx for idx, fn in enumerate(all_filenames) if filename_to_caption[fn] in relevant_captions]
        mAP_i2t += compute_ap(i2t_ranks, i2t_relevant, k)

    # Text to Image: one query per unique caption (duplicates share a column).
    unique_captions = set(filename_to_caption.values())
    for caption in unique_captions:
        relevant_items = get_relevant_items_for_text(caption, all_filenames, image_to_segs, org_images, filename_to_caption)

        # Use the column of the first filename carrying this caption.
        caption_index = all_filenames.index(next(filename for filename, cap in filename_to_caption.items() if cap == caption))
        t2i_ranks = torch.argsort(similarity[:, caption_index], descending=True).tolist()
        mAP_t2i += compute_ap(t2i_ranks, relevant_items, k)

    mAP_i2t /= len(all_filenames)
    mAP_t2i /= len(unique_captions)

    fabric.print(f"mAP@{k if k else 'all'} - Text-to-Image: {mAP_t2i:.4f} & Image-to-Text: {mAP_i2t:.4f}")
202
+
203
+
204
def main(args):
    """Load the (long-context) CLIP checkpoint and run the joint
    org+segment mAP evaluation on the selected dataset."""
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        precision="bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Short model aliases -> Hugging Face hub ids.
    if args.model == 'L-336':
        args.model = 'openai/clip-vit-large-patch14-336'
    elif args.model == 'L':
        args.model = 'openai/clip-vit-large-patch14'
    elif args.model == 'B':
        args.model = 'openai/clip-vit-base-patch16'
    elif args.model == 'G':
        args.model = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.AutoModel.from_pretrained(args.model).bfloat16()
        # Stretch the text positional embeddings before loading the finetuned
        # weights so the state-dict shapes line up.
        longclip_pos_embeddings(model, args.new_max_token)
        model.load_state_dict(torch.load(args.ckpt), strict=False)

    # NOTE(review): the repo ships 'datasets/DCI_test_joint_sim_max_1' (no
    # ':1.json' suffix) — confirm these paths exist on disk.
    if args.dataset == 'docci':
        query_list = 'datasets/docci_test_joint_sim_max_1:1.json'
    elif args.dataset == 'DCI':
        query_list = 'datasets/DCI_test_joint_sim_max_1:1.json'
    # NOTE(review): any other --dataset value leaves query_list unbound (NameError).

    with open(query_list) as f:
        query_list = json.loads(f.read())

    args.query_list = query_list

    query_dataset = QueryLoader(query_list, processor)
    query_loader = DataLoader(query_dataset, num_workers=num_workers, batch_size=micro_batch_size, shuffle=False, drop_last=False, pin_memory=False)
    query_loader = fabric.setup_dataloaders(query_loader)

    model.eval().to(fabric.device)
    test(fabric, model, query_loader, args.k)
244
+
245
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")  # allow TF32 matmuls on Ampere+

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='docci')
    parser.add_argument('--new_max_token', default=248, type=int)  # long-CLIP text length
    parser.add_argument("--model", type=str, default='B')  # short alias resolved in main()
    parser.add_argument("--ckpt", type=str, default='')
    parser.add_argument("--k", type=int, default=10, help="Limit rank calculation to top K results. Use None for all ranks.")
    args = parser.parse_args()

    main(args)
GOAL_github/retrieval_goal.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import os
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = "3"
5
+ import sys
6
+ from pathlib import Path
7
+ import lightning as L
8
+ import numpy as np
9
+ import torch
10
+ import transformers
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data import Dataset
13
+ from PIL import Image
14
+ import torch.nn.functional as F
15
+ from utils.func import *
16
+ from utils.transforms import *
17
+
18
# Hyperparameters
micro_batch_size = 32  # per-step batch size for the evaluation DataLoader
devices = 1  # number of GPUs handed to Lightning Fabric
num_workers = 1  # DataLoader worker processes
22
+
23
class QueryLoader(Dataset):
    """(image, caption) pairs tensorized with a CLIP processor."""

    def __init__(self, data_list, processor):
        self.data_list = data_list
        self.processor = processor

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, id: int):
        record = self.data_list[id]
        return Image.open(record["filename"]).convert("RGB")

    def _load_target(self, id: int):
        return self.data_list[id]["caption"]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        encoded = self.processor(
            images=self._load_image(idx),
            text=self._load_target(idx),
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=args.new_max_token,  # module-level args from the CLI
        )
        return encoded.pixel_values[0], encoded.input_ids[0]
45
+
46
+
47
def main(args):
    """Load CLIP (optionally with a long-context finetuned checkpoint) and
    run pairwise retrieval evaluation on the chosen test set."""
    # NOTE(review): CUDA_VISIBLE_DEVICES was already set to "3" at import
    # time; overwriting it here, possibly after CUDA init, may have no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        precision="bf16-mixed"  # changed from "32" to "bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Short model aliases -> Hugging Face hub ids.
    if args.model=='L-336':
        args.model = 'openai/clip-vit-large-patch14-336'
    elif args.model=='L':
        args.model = 'openai/clip-vit-large-patch14'
    elif args.model=='B':
        args.model = 'openai/clip-vit-base-patch16'
    elif args.model=='G':
        args.model = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.AutoModel.from_pretrained(args.model).bfloat16()
        # Stretch text positional embeddings to new_max_token before loading
        # the finetuned weights so state-dict shapes match.
        longclip_pos_embeddings(model, args.new_max_token)
        if args.ckpt:  # only load when a checkpoint path was provided
            model.load_state_dict(torch.load(args.ckpt), strict=False)

    # Dataset alias -> test-split JSON path.
    if args.dataset == 'docci':
        query_list = 'datasets/docci_test.json'
    elif args.dataset =='coco':
        query_list = 'datasets/coco_test.json'
    elif args.dataset =='flickr30k':
        query_list = 'datasets/flickr30k_test.json'
    elif args.dataset =='DCI':
        query_list = 'datasets/DCI_test.json'
    elif args.dataset =='urban':
        query_list = 'datasets/urban_dataset_test.json'
    elif args.dataset =='sharegpt4v':
        query_list = 'datasets/sharegpt4v_test.json'
    # NOTE(review): an unknown --dataset leaves query_list unbound (NameError).

    with open(query_list) as f:
        query_list = json.loads(f.read())

    args.query_list = query_list

    query_dataset = QueryLoader(query_list, processor)
    query_loader = DataLoader(query_dataset, num_workers=num_workers, batch_size=micro_batch_size, shuffle=False, drop_last=False, pin_memory=False)
    query_loader = fabric.setup_dataloaders(query_loader)

    model.eval().to(fabric.device)
    test(fabric, model, query_loader)
97
+
98
def compute_AP_and_recall_at_Ks(similarity, label_matrix, Ks):
    """AP@K and recall@K for a single query.

    similarity: (N,) scores against the gallery; label_matrix: (N,) bool/0-1
    relevance flags; Ks: iterable of cutoffs. Returns ``{K: {'AP', 'recall',
    'relevant_items'}}``. AP is normalized by the total number of relevant
    items (not min(K, total)), matching the original definition.

    Fix: previously divided by zero when a query had no relevant items;
    such queries now return all-zero metrics.
    """
    # Rank gallery indices by descending similarity.
    sorted_indices = torch.argsort(similarity, descending=True)
    results = {K: {'AP': 0.0, 'recall': 0.0, 'relevant_items': 0} for K in Ks}
    total_relevant_items = label_matrix.sum().item()
    if total_relevant_items == 0:
        return results
    for i, idx in enumerate(sorted_indices):
        if label_matrix[idx]:
            for K in Ks:
                if i < K:
                    results[K]['relevant_items'] += 1
                    precision = results[K]['relevant_items'] / (i + 1)
                    results[K]['AP'] += precision
    for K in Ks:
        results[K]['recall'] = results[K]['relevant_items'] / total_relevant_items
        results[K]['AP'] /= total_relevant_items
    return results
115
+
116
+
117
@torch.no_grad()
def test(fabric: L.Fabric, model: torch.nn.Module, query_loader) -> torch.Tensor:
    """Standard 1:1 retrieval evaluation: encode all pairs, then report
    Recall@{1,5,25,50} in both directions, assuming pair i <-> i is the
    unique ground truth."""
    fabric.print("Testing ...")

    images = torch.tensor([], dtype=torch.float32).to(fabric.device)
    texts = torch.tensor([], dtype=torch.float32).to(fabric.device)

    for samples in query_loader:
        image, text = samples

        x = model(pixel_values=image, input_ids=text)

        # L2-normalize so the dot product below is cosine similarity.
        x_i = F.normalize(x.image_embeds)
        x_t = F.normalize(x.text_embeds)

        images = torch.cat((images,x_i), dim=0)
        texts = torch.cat((texts,x_t), dim=0)

    # Calculate cosine similarity
    similarity = torch.mm(images, texts.t())
    # Image to Text (I2T): rank position of the matching caption per image row.
    sorted_indices_i2t = torch.argsort(similarity, descending=True)
    correct_indices = torch.arange(images.shape[0]).to(fabric.device)
    ranks_i2t = (sorted_indices_i2t == correct_indices[:, None]).nonzero(as_tuple=True)[1]
    # Text to Image (T2I)
    sorted_indices_t2i = torch.argsort(similarity.t(), descending=True)
    ranks_t2i = (sorted_indices_t2i == correct_indices[:, None]).nonzero(as_tuple=True)[1]
    # Calculate recall at different ranks for I2T
    recall_i2t_1 = (ranks_i2t < 1).float().mean().item() * 100
    recall_i2t_5 = (ranks_i2t < 5).float().mean().item() * 100
    recall_i2t_25 = (ranks_i2t < 25).float().mean().item() * 100
    recall_i2t_50 = (ranks_i2t < 50).float().mean().item() * 100
    # Calculate recall at different ranks for T2I
    recall_t2i_1 = (ranks_t2i < 1).float().mean().item() * 100
    recall_t2i_5 = (ranks_t2i < 5).float().mean().item() * 100
    recall_t2i_25 = (ranks_t2i < 25).float().mean().item() * 100
    recall_t2i_50 = (ranks_t2i < 50).float().mean().item() * 100
    # Print recall percentages for T2I
    fabric.print(f"Text-to-Image: {recall_t2i_1:.2f} & {recall_t2i_5:.2f} & {recall_t2i_25:.2f} & {recall_t2i_50:.2f}")
    # Print recall percentages for I2T
    fabric.print(f"Image-to-Text: {recall_i2t_1:.2f} & {recall_i2t_5:.2f} & {recall_i2t_25:.2f} & {recall_i2t_50:.2f}")
158
+
159
+
160
+
161
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")  # allow TF32 matmuls on Ampere+

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='urban')
    parser.add_argument('--new_max_token', default=248, type=int)  # long-CLIP text length
    parser.add_argument("--model", type=str, default='B')  # short alias resolved in main()
    parser.add_argument("--ckpt", type=str, default='')  # empty -> evaluate the base model
    args = parser.parse_args()

    main(args)
GOAL_github/utils/__pycache__/easydict.cpython-39.pyc ADDED
Binary file (3.64 kB). View file
 
GOAL_github/utils/__pycache__/func.cpython-310.pyc ADDED
Binary file (3.92 kB). View file
 
GOAL_github/utils/__pycache__/func.cpython-311.pyc ADDED
Binary file (9.11 kB). View file
 
GOAL_github/utils/__pycache__/func.cpython-39.pyc ADDED
Binary file (3.86 kB). View file
 
GOAL_github/utils/__pycache__/randaugment.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
GOAL_github/utils/__pycache__/randaugment.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
GOAL_github/utils/__pycache__/randaugment.cpython-39.pyc ADDED
Binary file (10.5 kB). View file
 
GOAL_github/utils/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.55 kB). View file
 
GOAL_github/utils/__pycache__/transforms.cpython-311.pyc ADDED
Binary file (5.44 kB). View file
 
GOAL_github/utils/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (3.61 kB). View file
 
GOAL_github/utils/func.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
def interpolate_pos_embeddings(model, new_image_size):
    """Resize the vision tower's positional embeddings for a new input size.

    Nearest-neighbor interpolates the (num_positions, dim) table to
    ``(new_image_size // patch_size) ** 2 + 1`` entries (+1 for the CLS
    token) and rebuilds ``position_ids`` to match.
    """
    vision_model = model.vision_model
    patch_size = vision_model.config.patch_size
    num_patches = (new_image_size // patch_size) ** 2 + 1
    # Interpolate along the sequence axis: (N, C) -> (1, C, N) -> (N', C).
    pos_embeddings = vision_model.embeddings.position_embedding.weight
    pos_embeddings = pos_embeddings.unsqueeze(0).permute(0, 2, 1)
    pos_embeddings = torch.nn.functional.interpolate(
        pos_embeddings, size=(num_patches), mode='nearest'
    ).squeeze(0).permute(1, 0)
    pos_embeddings = pos_embeddings.contiguous()  # ensure contiguous storage
    vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(pos_embeddings)
    # Keep position_ids in sync with the new table length.
    if hasattr(vision_model.embeddings, 'position_ids'):
        vision_model.embeddings.position_ids = torch.arange(0, num_patches).unsqueeze(0)
    else:
        # Fix: register the buffer on the embeddings module (where the hasattr
        # check looks and where the forward pass reads it), not on the tower.
        vision_model.embeddings.register_buffer('position_ids', torch.arange(0, num_patches).unsqueeze(0))
23
+
24
def interpolate_text_pos_embeddings(model, new_max_token):
    """Resize the text tower's positional embeddings to ``new_max_token``.

    Nearest-neighbor interpolates the (seq_len, dim) table along the sequence
    axis and rebuilds ``position_ids`` to match.
    """
    text_model = model.text_model
    # Interpolate along the sequence axis: (N, C) -> (1, C, N) -> (N', C).
    pos_embeddings = text_model.embeddings.position_embedding.weight
    pos_embeddings = pos_embeddings.unsqueeze(0).permute(0, 2, 1)

    pos_embeddings = torch.nn.functional.interpolate(
        pos_embeddings, size=(new_max_token), mode='nearest'
    ).squeeze(0).permute(1, 0)

    pos_embeddings = pos_embeddings.contiguous()  # ensure contiguous storage
    text_model.embeddings.position_embedding.weight = torch.nn.Parameter(pos_embeddings)

    # Keep position_ids in sync with the new table length.
    if hasattr(text_model.embeddings, 'position_ids'):
        text_model.embeddings.position_ids = torch.arange(0, new_max_token).unsqueeze(0)
    else:
        # Fix: register the buffer on the embeddings module (where the hasattr
        # check looks and where the forward pass reads it), not on the tower.
        text_model.embeddings.register_buffer('position_ids', torch.arange(0, new_max_token).unsqueeze(0))
43
+
44
def longclip_pos_embeddings(model, new_max_token):
    """Stretch CLIP's text positional embeddings Long-CLIP style.

    Keeps the first ``keep_len`` (best-trained) positions verbatim, fills
    later positions by 4x linear interpolation between consecutive original
    embeddings, then extrapolates four slots past the last one. The table is
    truncated to ``new_max_token`` rows and ``position_ids`` is rebuilt.

    Fixes vs. the original:
    * writes beyond ``new_max_token`` are now skipped — previously any
      new_max_token < 4*length - 3*keep_len raised IndexError;
    * the fallback ``position_ids`` buffer is registered on the embeddings
      module (where the hasattr check looks), not on the text tower.
    """
    text_model = model.text_model
    pos_embeddings_pre = text_model.embeddings.position_embedding.weight
    length, dim = pos_embeddings_pre.shape
    keep_len = 20  # head positions copied unchanged
    new_length = 4*length - 3*keep_len  # capacity of the fully stretched table
    if new_length < new_max_token:
        raise ValueError("new_max_token is too large")
    pos_embeddings_new = torch.zeros([new_max_token, dim], dtype=pos_embeddings_pre.dtype)

    def _set(row, value):
        # Only rows inside the truncated table exist.
        if row < new_max_token:
            pos_embeddings_new[row] = value

    for i in range(keep_len):
        _set(i, pos_embeddings_pre[i])
    # 4x expansion: each original position plus three interpolated points
    # toward its successor.
    for i in range(length-1-keep_len):
        _set(4*i + keep_len, pos_embeddings_pre[i + keep_len])
        _set(4*i + 1 + keep_len, 3*pos_embeddings_pre[i + keep_len]/4 + 1*pos_embeddings_pre[i+1+keep_len]/4)
        _set(4*i + 2 + keep_len, 2*pos_embeddings_pre[i+keep_len]/4 + 2*pos_embeddings_pre[i+1+keep_len]/4)
        _set(4*i + 3 + keep_len, 1*pos_embeddings_pre[i+keep_len]/4 + 3*pos_embeddings_pre[i+1+keep_len]/4)
    # Extrapolate the final four slots past the last original embedding.
    for j in range(4):
        _set(4*length - 3*keep_len - 4 + j,
             pos_embeddings_pre[length-1] + j*(pos_embeddings_pre[length-1] - pos_embeddings_pre[length-2])/4)
    text_model.embeddings.position_embedding.weight = torch.nn.Parameter(pos_embeddings_new)
    # Keep position_ids in sync with the new table length.
    if hasattr(text_model.embeddings, 'position_ids'):
        text_model.embeddings.position_ids = torch.arange(0, new_max_token).unsqueeze(0)
    else:
        text_model.embeddings.register_buffer('position_ids', torch.arange(0, new_max_token).unsqueeze(0))
71
+
72
+
73
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool hidden states over the unmasked (attended) positions."""
    masked = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    token_counts = attention_mask.sum(dim=1)[..., None]
    return masked.sum(dim=1) / token_counts
76
+
77
def last_token_pool(last_hidden_states, attention_mask):
    """Select each sequence's final attended token embedding.

    With left padding every sequence ends at the last position; with right
    padding the index is the per-row attention length minus one.
    """
    all_rows_end_attended = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if all_rows_end_attended:
        return last_hidden_states[:, -1]
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]
85
+
86
def batch_align(fabric, x):
    """All-gather ``x`` across ranks and flatten the (world, batch, ...)
    result into one (world*batch, -1) tensor, keeping gradients flowing."""
    gathered = fabric.all_gather(x, sync_grads=True)
    world, per_rank = gathered.shape[0], gathered.shape[1]
    return gathered.view(world * per_rank, -1)
89
+
90
cls_criterion = torch.nn.CrossEntropyLoss()

def clip_loss(logits):
    """Symmetric InfoNCE: cross-entropy over rows (image->text) and columns
    (text->image), with the matched pair on the diagonal as the target."""
    targets = torch.arange(len(logits), dtype=torch.long, device=logits.device)
    image_to_text = cls_criterion(logits, targets)
    text_to_image = cls_criterion(logits.t(), targets)
    return (image_to_text + text_to_image) / 2.0
95
+
96
def print_trainable_parameters(fabric, model):
    """Log trainable/total parameter counts and current CUDA memory use via
    the Fabric rank-zero printer."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    fabric.print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )
    fabric.print('Memory load of model: {} bytes'.format(torch.cuda.memory_allocated()))
GOAL_github/utils/randaugment.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code from salesforce/BLIP
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+
7
+ ## aug functions
8
def identity_func(img):
    """No-op augmentation: return the image unchanged."""
    return img
10
+
11
+
12
def autocontrast_func(img, cutoff=0):
    """
    same output as PIL.ImageOps.autocontrast

    Per channel: optionally ignore the darkest/brightest ``cutoff`` percent
    of pixels, then linearly remap the remaining [low, high] range to [0, 255].
    """
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            # No clipping: stretch between the channel's actual extremes.
            high, low = ch.max(), ch.min()
        else:
            # Intensity bounds after discarding `cut` pixels from each tail.
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            # Degenerate range: identity lookup table.
            table = np.arange(n_bins)
        else:
            # Linear remap: low -> 0, high -> 255, clamped at both ends.
            scale = (n_bins - 1) / (high - low)
            offset = np.multiply(low, -scale)
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
43
+
44
+
45
def equalize_func(img):
    """
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalize

    Per channel: build PIL's cumulative-histogram lookup table (with its
    half-step initial offset) and apply it.
    """
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            # Nearly constant channel: leave it untouched.
            return ch
        n = np.empty_like(hist)
        n[0] = step // 2  # PIL's half-step initial offset
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
67
+
68
+
69
def rotate_func(img, degree, fill=(0, 0, 0)):
    """
    like PIL, rotate by degree, not radians

    Rotates about the image center, filling exposed borders with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    rotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)
78
+
79
+
80
def solarize_func(img, thresh=128):
    """
    same output as PIL.ImageOps.solarize: invert all pixels >= thresh.

    (Docstring fix: this previously claimed to match ``posterize``.)
    """
    levels = np.arange(256)
    table = np.where(levels < thresh, levels, 255 - levels).astype(np.uint8)
    out = table[img]
    return out
88
+
89
+
90
def color_func(img, factor):
    """
    same output as PIL.ImageEnhance.Color

    Blends toward the ITU-R 601 grayscale projection; factor=1 is identity.
    Implemented as a single fused matrix: factor * blend + per-channel offset.
    """
    blend = np.array(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]],
        dtype=np.float32,
    )
    offset = np.array([[0.114], [0.587], [0.299]], dtype=np.float32)
    matrix = blend * factor + offset
    return np.matmul(img, matrix).clip(0, 255).astype(np.uint8)
107
+
108
+
109
def contrast_func(img, factor):
    """
    same output as PIL.ImageEnhance.Contrast

    Scales each gray level's distance from the luma-weighted global mean by
    ``factor``; factor=1 is (near-)identity.
    """
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    levels = np.arange(256)
    table = ((levels - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return table[img]
121
+
122
+
123
def brightness_func(img, factor):
    """
    same output as PIL.ImageEnhance.Brightness: scale all levels by
    ``factor``. (Docstring fix: previously claimed to match Contrast.)
    """
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out
130
+
131
+
132
def sharpness_func(img, factor):
    """
    The differences between this result and PIL are all on the 4 boundaries;
    the center areas are the same.

    Blends between a 3x3 smoothed version (factor=0) and the original
    (factor=1); factor > 1 over-sharpens. Border pixels keep original values.
    """
    # PIL's smooth kernel: ones with center 5, normalized by 13.
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        # Linear interpolation toward/away from the smoothed interior only.
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        out = out.astype(np.uint8)
    return out
151
+
152
+
153
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Horizontal shear: x' = x + factor * y; borders filled with ``fill``."""
    height, width = img.shape[0], img.shape[1]
    warp = np.array([[1, factor, 0], [0, 1, 0]], dtype=np.float32)
    return cv2.warpAffine(
        img, warp, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
160
+
161
+
162
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Shifts the image left by ``offset`` pixels (negative shifts right).
    """
    height, width = img.shape[0], img.shape[1]
    warp = np.array([[1, 0, -offset], [0, 1, 0]], dtype=np.float32)
    return cv2.warpAffine(
        img, warp, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
172
+
173
+
174
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Shifts the image up by ``offset`` pixels (negative shifts down).
    """
    height, width = img.shape[0], img.shape[1]
    warp = np.array([[1, 0, 0], [0, 1, -offset]], dtype=np.float32)
    return cv2.warpAffine(
        img, warp, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
184
+
185
+
186
def posterize_func(img, bits):
    """
    same output as PIL.ImageOps.posterize: keep only the top ``bits`` bits.

    Fix: the old mask ``np.uint8(255 << (8 - bits))`` overflows uint8 for
    bits < 8 — it relied on deprecated integer wraparound and raises
    OverflowError on NumPy >= 2. Mask within 8 bits before converting.
    """
    mask = np.uint8((0xFF << (8 - bits)) & 0xFF)
    return np.bitwise_and(img, mask)
192
+
193
+
194
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Vertical shear: y' = y + factor * x; borders filled with ``fill``."""
    height, width = img.shape[0], img.shape[1]
    warp = np.array([[1, 0, 0], [factor, 1, 0]], dtype=np.float32)
    return cv2.warpAffine(
        img, warp, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
201
+
202
+
203
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """Overwrite a random square of side ~``pad_size`` (centered uniformly at
    random, clipped at the borders) with the ``replace`` color."""
    replace = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    rand_y, rand_x = np.random.random(2)
    half = pad_size // 2
    center_y, center_x = int(rand_y * height), int(rand_x * width)
    top, bottom = max(center_y - half, 0), min(center_y + half, height)
    left, right = max(center_x - half, 0), min(center_x + half, width)
    out = img.copy()
    out[top:bottom, left:right, :] = replace
    return out
214
+
215
+
216
+ ### level to args
217
def enhance_level_to_args(MAX_LEVEL):
    """Map level -> (factor,) in [0.1, 1.9] for the Enhance-style ops."""
    def level_to_args(level):
        factor = (level / MAX_LEVEL) * 1.8 + 0.1
        return (factor,)

    return level_to_args
222
+
223
+
224
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Map level -> (shear factor in [-0.3, 0.3], fill); sign is random."""
    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            magnitude = -magnitude
        return (magnitude, replace_value)

    return level_to_args
232
+
233
+
234
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Map level -> (pixel offset in [-translate_const, +], fill); sign random."""
    def level_to_args(level):
        offset = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            offset = -offset
        return (offset, replace_value)

    return level_to_args
242
+
243
+
244
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Map level -> (cutout square size, fill)."""
    def level_to_args(level):
        size = int((level / MAX_LEVEL) * cutout_const)
        return (size, replace_value)

    return level_to_args
250
+
251
+
252
def solarize_level_to_args(MAX_LEVEL):
    """Map level -> (solarize threshold scaled into [0, 256],)."""
    def level_to_args(level):
        thresh = int((level / MAX_LEVEL) * 256)
        return (thresh,)

    return level_to_args
258
+
259
+
260
def none_level_to_args(level):
    """For ops that take no magnitude argument: always an empty tuple."""
    return ()
262
+
263
+
264
def posterize_level_to_args(MAX_LEVEL):
    """Map level -> (bits to keep, scaled into [0, 4],)."""
    def level_to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits,)

    return level_to_args
270
+
271
+
272
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Map level -> (degrees in [-30, 30], fill); sign is random."""
    def level_to_args(level):
        degrees = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            degrees = -degrees
        return (degrees, replace_value)

    return level_to_args
280
+
281
+
282
# Op name -> augmentation implementation.
func_dict = {
    "Identity": identity_func,
    "AutoContrast": autocontrast_func,
    "Equalize": equalize_func,
    "Rotate": rotate_func,
    "Solarize": solarize_func,
    "Color": color_func,
    "Contrast": contrast_func,
    "Brightness": brightness_func,
    "Sharpness": sharpness_func,
    "ShearX": shear_x_func,
    "TranslateX": translate_x_func,
    "TranslateY": translate_y_func,
    "Posterize": posterize_func,
    "ShearY": shear_y_func,
}

translate_const = 10  # max translate offset (pixels) at full magnitude
MAX_LEVEL = 10  # magnitude scale shared by all level-to-args mappers
replace_value = (128, 128, 128)  # gray fill for geometric ops
# Op name -> callable turning a magnitude level into the op's extra args.
arg_dict = {
    "Identity": none_level_to_args,
    "AutoContrast": none_level_to_args,
    "Equalize": none_level_to_args,
    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
    "Solarize": solarize_level_to_args(MAX_LEVEL),
    "Color": enhance_level_to_args(MAX_LEVEL),
    "Contrast": enhance_level_to_args(MAX_LEVEL),
    "Brightness": enhance_level_to_args(MAX_LEVEL),
    "Sharpness": enhance_level_to_args(MAX_LEVEL),
    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "Posterize": posterize_level_to_args(MAX_LEVEL),
    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}
318
+
319
+
320
class RandomAugment(object):
    """Apply N randomly chosen augmentations, each with probability 0.5 at
    magnitude M (RandAugment, as used in BLIP)."""

    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL  # convert PIL input to ndarray before augmenting
        self.augs = augs if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Sample N op names (with replacement) as (name, prob, level) triples."""
        chosen = np.random.choice(self.augs, self.N)
        return [(name, 0.5, self.M) for name in chosen]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        for name, prob, level in self.get_random_ops():
            if np.random.random() > prob:
                continue
            extra_args = arg_dict[name](level)
            img = func_dict[name](img, *extra_args)
        return img
344
+
345
+
346
if __name__ == "__main__":
    # Smoke test: run the default pipeline once on a random "image".
    a = RandomAugment()
    img = np.random.randn(32, 32, 3)
    a(img)
GOAL_github/utils/transforms.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from torchvision.transforms.functional import InterpolationMode
3
+ from torchvision.transforms import Compose, CenterCrop, ToTensor, Normalize, Resize
4
+ from utils.randaugment import RandomAugment
5
+ import torch
6
+ import torchvision.transforms.functional as FT
7
+ import math
8
+ import torch.nn.functional as F
9
+
10
# CLIP's standard per-channel normalization statistics (mean, std for RGB).
normalize = transforms.Normalize(
    (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
)
13
+
14
class transform_train:
    """Training-time preprocessing: random resized crop and horizontal flip,
    followed by tensor conversion and CLIP normalization."""

    def __init__(self, image_size=384, min_scale=0.5):
        # min_scale controls the smallest crop area fraction kept by the
        # RandomResizedCrop; BICUBIC matches CLIP's own preprocessing.
        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    image_size,
                    scale=(min_scale, 1.0),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )

    def __call__(self, img):
        # Apply the composed pipeline to a single PIL image.
        return self.transform(img)
31
+
32
+
33
+ # class transform_train:
34
+ # def __init__(self, image_size=384, min_scale=0.5):
35
+ # self.transform = transforms.Compose(
36
+ # [
37
+ # transforms.RandomResizedCrop(
38
+ # image_size,
39
+ # scale=(min_scale, 1.0),
40
+ # interpolation=InterpolationMode.BICUBIC,
41
+ # ),
42
+ # transforms.RandomHorizontalFlip(),
43
+ # RandomAugment(
44
+ # 2,
45
+ # 5,
46
+ # isPIL=True,
47
+ # augs=[
48
+ # "Identity",
49
+ # "AutoContrast",
50
+ # "Brightness",
51
+ # "Sharpness",
52
+ # "Equalize",
53
+ # "ShearX",
54
+ # "ShearY",
55
+ # "TranslateX",
56
+ # "TranslateY",
57
+ # "Rotate",
58
+ # ],
59
+ # ),
60
+ # transforms.ToTensor(),
61
+ # normalize,
62
+ # ]
63
+ # )
64
+
65
+ # def __call__(self, img):
66
+ # return self.transform(img)
67
+
68
+
69
class transform_test(transforms.Compose):
    """Evaluation-time preprocessing: deterministic square resize, tensor
    conversion, and CLIP normalization.

    Fix: this class subclasses ``transforms.Compose`` but previously never
    called ``super().__init__``, leaving ``self.transforms`` unset so
    inherited methods such as ``__repr__`` raised ``AttributeError``.
    The base is now initialized; ``self.transform`` is kept so existing
    callers that read it keep working.
    """

    def __init__(self, image_size=384):
        pipeline = [
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            normalize,
        ]
        # Properly initialize the Compose base class.
        super().__init__(pipeline)
        # Backward compatibility: some code may access .transform directly.
        self.transform = transforms.Compose(pipeline)

    def __call__(self, img):
        return self.transform(img)
84
+
85
+
86
+
87
class TargetPad:
    """
    If an image aspect ratio is above a target ratio, pad the image to match such target ratio.
    For more details see Baldrati et al. 'Effective conditioned and composed image retrieval combining clip-based features.' Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022).
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: target ratio
        :param size: preprocessing output dimension
        """
        self.size = size
        self.target_ratio = target_ratio

    def __call__(self, image):
        width, height = image.size
        long_side = max(width, height)
        short_side = min(width, height)
        # Images whose ratio is already at or below the target pass through.
        if long_side / short_side < self.target_ratio:
            return image
        # Total extent that would bring the aspect ratio down to the target.
        desired_extent = long_side / self.target_ratio
        pad_w = max(int((desired_extent - width) / 2), 0)
        pad_h = max(int((desired_extent - height) / 2), 0)
        # Symmetric constant (black) padding: left, top, right, bottom.
        return FT.pad(image, [pad_w, pad_h, pad_w, pad_h], 0, 'constant')
111
+
112
+
113
+ def _convert_image_to_rgb(image):
114
+ return image.convert("RGB")
115
+
116
def targetpad_transform(target_ratio: float, dim: int) -> Compose:
    """
    CLIP-like preprocessing transform computed after using TargetPad pad.

    Fix: the return annotation previously claimed ``torch.Tensor``, but the
    function returns a torchvision ``Compose`` callable, not a tensor.

    :param target_ratio: target ratio for TargetPad
    :param dim: image output dimension
    :return: CLIP-like torchvision Compose transform
    """
    return Compose([
        TargetPad(target_ratio, dim),
        Resize(dim, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(dim),
        _convert_image_to_rgb,
        ToTensor(),
        # CLIP's standard normalization statistics.
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
GOAL_github/visualization/visualization_attentionmap_longtestset.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import os
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = "3"
5
+ import sys
6
+ from pathlib import Path
7
+ import lightning as L
8
+ import numpy as np
9
+ import torch
10
+ import transformers
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data import Dataset
13
+ from PIL import Image
14
+ import torch.nn.functional as F
15
+ from utils.func import *
16
+ from utils.transforms import *
17
+ import matplotlib.pyplot as plt
18
+ import cv2
19
+
20
# Hyperparameters
micro_batch_size = 1  # one image per batch: the dataset below serves a single image
devices = 1
num_workers = 1
24
+
25
class ImageAttentionLoader(Dataset):
    """Single-image dataset: loads one image and yields its processed pixel
    tensor together with the raw RGB array (used for visualization overlays)."""

    def __init__(self, image_path, processor):
        # processor is expected to be a HF image processor — confirm at call site.
        self.image_path = image_path
        self.processor = processor

    def __len__(self):
        # Always exactly one sample: this loader wraps a single image path.
        return 1

    def _load_image(self):
        return Image.open(self.image_path).convert("RGB")

    def __getitem__(self, idx):
        image = self._load_image()
        # NOTE(review): original_size is recorded but never read in this file.
        self.original_size = image.size
        # Process image without return_tensors
        data = self.processor(images=image)
        # Convert to tensor manually
        pixel_values = torch.tensor(data['pixel_values'])
        # Ensure correct shape [C, H, W]
        if len(pixel_values.shape) > 3:
            pixel_values = pixel_values.squeeze()
        return pixel_values, np.array(image)
47
+
48
def get_attention_maps(model, image_input):
    """Run the vision tower and return the last layer's self-attention,
    averaged over heads, with the CLS token stripped.

    :param model: HF CLIP-style model exposing ``vision_model``
    :param image_input: pixel tensor, [C, H, W] or [B, C, H, W]
    :return: float32 CPU tensor of patch-to-patch attention
    """
    print("Input shape before processing:", image_input.shape)

    # Ensure correct shape and device
    if len(image_input.shape) == 3:
        image_input = image_input.unsqueeze(0)  # Add batch dimension

    image_input = image_input.to(model.device)
    print("Input shape after processing:", image_input.shape)

    outputs = model.vision_model(pixel_values=image_input, output_attentions=True)

    # Take the attention maps of the last transformer layer
    attention_maps = outputs.attentions[-1]
    print("Attention maps shape:", attention_maps.shape)

    # Average the attention over all heads
    attention_maps = attention_maps.mean(dim=1)

    # Keep only patch-to-patch attention (drop the CLS token row and column)
    attention_maps = attention_maps[:, 1:, 1:]

    # Move to CPU and cast to float32 (the model may run in bf16)
    attention_maps = attention_maps.squeeze().cpu().float()
    print("Final attention maps shape:", attention_maps.shape)

    return attention_maps
75
+
76
def visualize_attention(image, attention_maps, output_path, model_name):
    """Render the original image, its attention heatmap, and an overlay
    side by side, then save the figure to ``output_path``.

    Fixes: the patch size derived from ``model_name`` was computed but never
    used (dead code, removed); the min-max normalization now guards against a
    constant attention map, which previously caused division by zero.

    :param image: HxWx3 array of the original image
    :param attention_maps: flat per-patch attention scores (length must be a
        perfect square)
    :param output_path: path where the figure is written
    :param model_name: kept for interface compatibility (no longer consumed)
    """
    # Reshape the flat per-patch scores into a square grid.
    h = int(np.sqrt(attention_maps.shape[0]))
    attention_maps = attention_maps.reshape(h, h)

    # Upsample the attention grid to the original image resolution.
    attention_maps = F.interpolate(
        attention_maps.unsqueeze(0).unsqueeze(0),
        size=image.shape[:2],
        mode='bicubic'
    ).squeeze().numpy()

    # Min-max normalize to [0, 1]; guard against a constant map (zero range).
    amin = attention_maps.min()
    amax = attention_maps.max()
    attention_maps = (attention_maps - amin) / max(amax - amin, 1e-8)

    # Three panels: original, heatmap, overlay.
    plt.figure(figsize=(15, 5))

    # Original image
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    # Attention heatmap
    plt.subplot(1, 3, 2)
    plt.imshow(attention_maps, cmap='jet')
    plt.title('Self-Attention Map')
    plt.axis('off')

    # Overlay of the heatmap on the original image
    plt.subplot(1, 3, 3)
    plt.imshow(image)
    plt.imshow(attention_maps, cmap='jet', alpha=0.5)
    plt.title('Overlay')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
124
+
125
def main(args):
    """Load a (optionally fine-tuned) CLIP model, extract the last-layer
    self-attention for one image, and save the visualization to disk."""
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        precision="bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Map the short model alias to the full Hugging Face model id.
    if args.model=='L-336':
        args.model = 'openai/clip-vit-large-patch14-336'
    elif args.model=='L':
        args.model = 'openai/clip-vit-large-patch14'
    elif args.model=='B':
        args.model = 'openai/clip-vit-base-patch32'
    elif args.model=='G':
        args.model = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.AutoModel.from_pretrained(args.model, output_attentions=True).bfloat16()
        # Extend the text position embeddings to new_max_token (LongCLIP-style).
        longclip_pos_embeddings(model, args.new_max_token)
        if args.ckpt:
            # strict=False: the checkpoint may cover only a subset of weights.
            model.load_state_dict(torch.load(args.ckpt), strict=False)

    # Build the single-image data loader
    dataset = ImageAttentionLoader(args.image_path, processor)
    dataloader = DataLoader(dataset, batch_size=micro_batch_size, shuffle=False)

    # Put the model in evaluation mode
    model.eval()
    model = model.to(fabric.device)

    # Generate and visualize the attention map
    with torch.no_grad():
        for batch in dataloader:
            image_input, original_image = batch
            print("Original input shape:", image_input.shape)
            attention_maps = get_attention_maps(model, image_input)

            # Average attention score received by each patch
            avg_attention = attention_maps.mean(dim=0)

            visualize_attention(
                original_image[0],
                avg_attention,
                args.output_path,
                args.model
            )
            break
176
+
177
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    # CLI: single image in, one attention-map figure out.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True, help="Path to input image")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save visualization")
    parser.add_argument('--new_max_token', default=248, type=int)
    parser.add_argument("--model", type=str, default='L')
    parser.add_argument("--ckpt", type=str, default='')
    args = parser.parse_args()

    main(args)
GOAL_github/visualization/visualization_retreival.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import os
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = "3"
5
+ import sys
6
+ from pathlib import Path
7
+ import lightning as L
8
+ import numpy as np
9
+ import torch
10
+ import transformers
11
+ from torch.utils.data import DataLoader, Dataset
12
+ from PIL import Image
13
+ import torch.nn.functional as F
14
+ from utils.func import *
15
+ import matplotlib.pyplot as plt
16
+ import cv2
17
+ import textwrap
18
+
19
def collate_fn(batch):
    """Collate gallery samples: pad every text sequence to the batch maximum
    with zeros, stack tensor fields, and keep paths/captions as lists."""
    # Longest token sequence in this batch determines the padded length.
    target_len = max(len(sample['input_ids']) for sample in batch)

    pixel_values = []
    input_ids = []
    attention_masks = []
    image_paths = []
    captions = []

    for sample in batch:
        # Image tensors need no padding.
        pixel_values.append(sample['pixel_values'])

        # Right-pad ids and mask with zeros up to target_len.
        shortfall = target_len - len(sample['input_ids'])
        zero_tail = torch.zeros(shortfall, dtype=torch.long)
        input_ids.append(torch.cat([sample['input_ids'], zero_tail]))
        attention_masks.append(torch.cat([sample['attention_mask'], zero_tail]))

        image_paths.append(sample['image_path'])
        captions.append(sample['caption'])

    # Assemble the batch dict.
    return {
        'pixel_values': torch.stack(pixel_values),
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_masks),
        'image_path': image_paths,
        'caption': captions,
    }
60
+
61
class JsonGalleryDataset(Dataset):
    """Gallery of (image, caption) pairs described by a JSON list of
    ``{"filename": ..., "caption": ...}`` records.

    Fix: on a loading error the original unconditionally fell back to
    ``self.__getitem__(0)``; when index 0 itself is broken this recurses
    forever. The fallback now re-raises when idx == 0.
    """

    def __init__(self, json_path, processor, max_length=248):
        """
        :param json_path: JSON file with a list of filename/caption records
        :param processor: HF CLIP processor used for both image and text
        :param max_length: token padding length for captions
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = item['filename']
        caption = item['caption']

        try:
            image = Image.open(image_path).convert("RGB")
            # Run image and caption through the shared CLIP processor.
            processed_image = self.processor(images=image, return_tensors="pt")
            processed_text = self.processor(
                text=caption,
                return_tensors="pt",
                padding='max_length',
                max_length=self.max_length,
                truncation=True
            )

            return {
                'pixel_values': processed_image['pixel_values'].squeeze(0),
                'input_ids': processed_text['input_ids'].squeeze(0),
                'attention_mask': processed_text['attention_mask'].squeeze(0),
                'image_path': image_path,
                'caption': caption
            }
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            if idx == 0:
                # Avoid infinite recursion when the fallback sample is broken.
                raise
            # Best-effort fallback: substitute the first sample.
            return self.__getitem__(0)
98
+
99
def compute_similarities(model, query_features, gallery_features):
    """Return the cosine-similarity matrix between query and gallery rows.

    ``model`` is accepted for interface compatibility but is not used.
    """
    q = F.normalize(query_features, dim=-1)
    g = F.normalize(gallery_features, dim=-1)
    # Row i, column j: cosine similarity of query i vs gallery item j.
    return torch.mm(q, g.t())
104
+
105
def process_query(model, processor, query, device, is_image=False, max_length=248):
    """Encode a query (image path or raw text) into a CLIP feature tensor.

    :param model: CLIP model exposing get_image_features/get_text_features
    :param processor: matching HF CLIP processor
    :param query: image file path when is_image, otherwise the text string
    :param device: device the inputs are moved to
    :param is_image: selects the image vs text encoding branch
    :param max_length: token padding length for text queries
    :return: feature tensor of shape [1, feature_dim]
    """
    with torch.no_grad():
        if is_image:
            image = Image.open(query).convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(device)
            features = model.get_image_features(pixel_values)
        else:
            inputs = processor(
                text=query,
                return_tensors="pt",
                padding='max_length',
                max_length=max_length,
                truncation=True
            )
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)
            features = model.get_text_features(input_ids, attention_mask)
    return features
124
+
125
def visualize_results(query, results, output_path, is_image_query=False):
    """Save a two-panel figure: the query (image or wrapped text) next to
    the top-1 retrieved result (caption for image queries, image for text
    queries).

    :param query: image path or text string, matching is_image_query
    :param results: list of (image_path, similarity, caption); only [0] is used
    :param output_path: destination file for the figure
    :param is_image_query: True when the query is an image
    """
    # Only a single (top-1) result is displayed
    n_results = 1
    plt.rcParams['figure.facecolor'] = 'white'
    plt.rcParams['axes.facecolor'] = 'white'
    fig, axes = plt.subplots(1, 1 + n_results, figsize=(10, 4))  # adjust figure size

    if is_image_query:
        # Image query: left panel shows the query image
        query_img = Image.open(query)
        axes[0].imshow(query_img)
        axes[0].set_title("Image query", fontsize=12, pad=10)

        # Show the result (caption only)
        img_path, similarity, caption = results[0]  # use only the first result
        axes[1].set_facecolor('white')
        wrapped_text = textwrap.fill(caption, width=40)

        # Render plain text without a surrounding text box
        axes[1].text(0.5, 0.5, wrapped_text,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=axes[1].transAxes,
                    wrap=True,
                    fontsize=10)
        # axes[1].set_title("Recall@1", fontsize=12, pad=10)

    else:
        # Text query: left panel shows the wrapped query text
        axes[0].set_facecolor('white')
        wrapped_text = textwrap.fill(query, width=40)

        # Render plain text without a surrounding text box
        axes[0].text(0.5, 0.5, wrapped_text,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=axes[0].transAxes,
                    wrap=True,
                    fontsize=10)
        axes[0].set_title("Text query", fontsize=12, pad=10)

        # Show the result (image only)
        img_path, similarity, caption = results[0]  # use only the first result
        img = Image.open(img_path)
        axes[1].imshow(img)
        # axes[1].set_title("Recall@1", fontsize=12, pad=10)

    # Common style settings applied to every subplot
    for ax in axes:
        ax.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['top'].set_visible(True)
        ax.spines['right'].set_visible(True)
        ax.spines['bottom'].set_visible(True)
        ax.spines['left'].set_visible(True)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0)

    plt.savefig(output_path,
                dpi=300,
                bbox_inches='tight',
                facecolor='white',
                edgecolor='none',
                pad_inches=0.2)
    plt.close()
192
+
193
def main(args):
    """Cross-modal retrieval demo: encode a gallery from JSON, encode the
    query (image or text), and save a visualization of the top-1 match."""
    fabric = L.Fabric(
        accelerator="cuda",
        devices=1,
        precision="bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Map the short model alias to the full Hugging Face model id
    if args.model == 'L-336':
        model_name = 'openai/clip-vit-large-patch14-336'
    elif args.model == 'L':
        model_name = 'openai/clip-vit-large-patch14'
    elif args.model == 'B':
        model_name = 'openai/clip-vit-base-patch32'
    elif args.model == 'G':
        model_name = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    # Load the model and processor
    processor = transformers.CLIPProcessor.from_pretrained(model_name)
    model = transformers.CLIPModel.from_pretrained(model_name).bfloat16()

    # First extend the text position embeddings (LongCLIP-style)
    longclip_pos_embeddings(model, args.new_max_token)

    # Then load the fine-tuned weights, if provided
    if args.ckpt:
        state_dict = torch.load(args.ckpt, weights_only=True)
        # strict=False: the checkpoint may cover only a subset of weights
        model.load_state_dict(state_dict, strict=False)

    model = model.to(fabric.device)
    model.eval()

    # Prepare the gallery dataset
    dataset = JsonGalleryDataset(args.gallery_json, processor, args.new_max_token)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn
    )

    # Decide whether the query is an image path or free text
    is_image_query = args.query.endswith(('.jpg', '.jpeg', '.png'))

    # Extract gallery features (opposite modality to the query)
    gallery_features = []
    gallery_items = []

    with torch.no_grad():
        for batch in dataloader:
            if is_image_query:
                # Image query: gallery side uses text features
                input_ids = batch['input_ids'].to(fabric.device)
                attention_mask = batch['attention_mask'].to(fabric.device)
                features = model.get_text_features(input_ids, attention_mask)
            else:
                # Text query: gallery side uses image features
                pixel_values = batch['pixel_values'].to(fabric.device)
                features = model.get_image_features(pixel_values)

            gallery_features.append(features)

            for path, caption in zip(batch['image_path'], batch['caption']):
                gallery_items.append({
                    'image_path': path,
                    'caption': caption
                })

    gallery_features = torch.cat(gallery_features, dim=0)

    # Extract the query features
    query_features = process_query(model, processor, args.query, fabric.device, is_image_query, args.new_max_token)

    # Compute similarities and keep only the top-1 result
    similarities = compute_similarities(model, query_features, gallery_features)
    top_k = 1  # changed to fetch a single result
    top_k_similarities, top_k_indices = torch.topk(similarities[0].float(), k=top_k)

    # Collect the results
    results = []
    top_k_similarities_np = top_k_similarities.cpu().numpy()
    top_k_indices_np = top_k_indices.cpu().numpy()

    for sim, idx in zip(top_k_similarities_np, top_k_indices_np):
        item = gallery_items[idx]
        results.append((
            item['image_path'],
            sim,
            item['caption']
        ))

    # Visualize the result
    visualize_results(
        args.query,
        results,
        args.output_path,
        is_image_query
    )
294
+
295
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    # CLI: one query (image path or text) against a JSON-described gallery.
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str, required=True,
                        help="Path to query image or text query")
    parser.add_argument("--gallery_json", type=str, required=True,
                        help="Path to JSON file containing gallery information")
    parser.add_argument("--output_path", type=str, required=True,
                        help="Path to save visualization")
    parser.add_argument("--model", type=str, default='L',
                        choices=['L-336', 'L', 'B', 'G'])
    parser.add_argument("--ckpt", type=str, default='',
                        help="Path to custom checkpoint")
    parser.add_argument("--new_max_token", type=int, default=248,
                        help="Maximum number of text tokens")

    args = parser.parse_args()
    main(args)