Upload 29 files
Browse files- .gitattributes +1 -0
- GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortestset.py +170 -0
- GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortrainset.py +170 -0
- GOAL_github/datasets/DCI_segment_only_sim_max_del_org.json +0 -0
- GOAL_github/datasets/DCI_test.json +0 -0
- GOAL_github/datasets/DCI_test_joint_sim_max_1 +0 -0
- GOAL_github/datasets/DCI_train_del_org.json +0 -0
- GOAL_github/datasets/docci_segment_sim_bbox_del_org.json +3 -0
- GOAL_github/datasets/docci_test.json +0 -0
- GOAL_github/datasets/docci_test_joint_sim_max_1 +0 -0
- GOAL_github/datasets/docci_train_del_org.json +0 -0
- GOAL_github/datasets/urban_dataset_test.json +0 -0
- GOAL_github/goal.py +469 -0
- GOAL_github/mAP_goal_jointtest.py +256 -0
- GOAL_github/retrieval_goal.py +171 -0
- GOAL_github/utils/__pycache__/easydict.cpython-39.pyc +0 -0
- GOAL_github/utils/__pycache__/func.cpython-310.pyc +0 -0
- GOAL_github/utils/__pycache__/func.cpython-311.pyc +0 -0
- GOAL_github/utils/__pycache__/func.cpython-39.pyc +0 -0
- GOAL_github/utils/__pycache__/randaugment.cpython-310.pyc +0 -0
- GOAL_github/utils/__pycache__/randaugment.cpython-311.pyc +0 -0
- GOAL_github/utils/__pycache__/randaugment.cpython-39.pyc +0 -0
- GOAL_github/utils/__pycache__/transforms.cpython-310.pyc +0 -0
- GOAL_github/utils/__pycache__/transforms.cpython-311.pyc +0 -0
- GOAL_github/utils/__pycache__/transforms.cpython-39.pyc +0 -0
- GOAL_github/utils/func.py +106 -0
- GOAL_github/utils/randaugment.py +349 -0
- GOAL_github/utils/transforms.py +130 -0
- GOAL_github/visualization/visualization_attentionmap_longtestset.py +188 -0
- GOAL_github/visualization/visualization_retreival.py +313 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
GOAL_github/datasets/docci_segment_sim_bbox_del_org.json filter=lfs diff=lfs merge=lfs -text
|
GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortestset.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
from transformers import pipeline, CLIPProcessor, CLIPModel
|
| 6 |
+
import cv2
|
| 7 |
+
import torch
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
# SAM 모델 로드
|
| 13 |
+
generator = pipeline("mask-generation", device=2, points_per_batch=128)
|
| 14 |
+
|
| 15 |
+
# CLIP 모델 로드
|
| 16 |
+
model_name = "openai/clip-vit-large-patch14-336"
|
| 17 |
+
clip_model = CLIPModel.from_pretrained(model_name)
|
| 18 |
+
clip_processor = CLIPProcessor.from_pretrained(model_name)
|
| 19 |
+
|
| 20 |
+
def keep_original_background(image, mask):
    """Blend a segment mask with *image*, keeping the original pixels as background.

    NOTE(review): both the foreground and the background term come from the
    same source image, so ``img * m + (1 - m) * img`` reduces to the original
    image.  The function is kept as-is for interface compatibility with the
    segment-cropping loop below — confirm whether a different background
    (e.g. blurred/black) was intended.
    """
    pixels = np.array(image)
    mask3 = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    blended = pixels * mask3 + (1 - mask3) * np.array(image)
    return Image.fromarray(np.uint8(blended))
|
| 26 |
+
|
| 27 |
+
# Resize helper: very large photos are downscaled before SAM to bound memory use.
def resize_image_if_needed(image_path, max_size=(2048, 1536), trigger_size=(4000, 3000)):
    """Open ``image_path`` and return an in-memory copy, downscaled when huge.

    Args:
        image_path: path to the image file on disk.
        max_size: (width, height) bound passed to ``Image.thumbnail`` when a
            resize is triggered (aspect ratio is preserved).
        trigger_size: (width, height) thresholds; the image is resized only
            when either dimension exceeds its threshold.  Defaults keep the
            original behavior (resize images wider than 4000 px or taller
            than 3000 px), which were previously hard-coded.

    Returns:
        PIL.Image.Image: a copy detached from the underlying file handle.
    """
    with Image.open(image_path) as img:
        if img.size[0] > trigger_size[0] or img.size[1] > trigger_size[1]:
            # thumbnail() resizes in place and never upscales.
            img.thumbnail(max_size, Image.LANCZOS)
        # Single exit replaces the duplicated `return img.copy()` of the original.
        return img.copy()
|
| 34 |
+
|
| 35 |
+
# Image-caption matching with CLIP.
def get_clip_similarity(images, texts):
    """Score every image against every text with the global CLIP model.

    Args:
        images: list of PIL images.
        texts: list of caption strings.

    Returns:
        numpy array of shape (num_images, num_texts) with the CLIP
        image-to-text logits, or ``None`` when preprocessing/inference
        fails (callers skip the sample in that case).
    """
    try:
        inputs = clip_processor(text=texts, images=images, return_tensors="pt", padding=True)
        # Inference only — no_grad avoids building an autograd graph,
        # which the original version paid for and then threw away.
        with torch.no_grad():
            outputs = clip_model(**inputs)
        return outputs.logits_per_image.cpu().numpy()
    except Exception as e:
        # Deliberate best-effort: one bad sample must not kill the batch job.
        print(f"Error in CLIP processing: {str(e)}")
        return None
|
| 44 |
+
|
| 45 |
+
def tokenize_caption(text):
    """Split *text* into sentences on '.', dropping empty fragments."""
    sentences = []
    for fragment in text.split('.'):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
|
| 48 |
+
|
| 49 |
+
# Base paths for inputs (SAM/annotation folders) and outputs (matched segments).
matching_results_path = "matching_results"
output_base_dir = "segment_with_background_DCI_test_set_max_0.01"
os.makedirs(output_base_dir, exist_ok=True)

# Collect the per-image folders for this split.
# NOTE(review): variable names say "train" but this script processes the
# test split (prefix 'test_sa_'); the naming is left over from the train script.
train_folders = [f for f in os.listdir(matching_results_path) if f.startswith('test_sa_')]

# Process each folder independently; every step is best-effort so one bad
# sample cannot abort the whole batch run.
for train_folder in tqdm(train_folders, desc="Processing folders"):
    folder_path = os.path.join(matching_results_path, train_folder)

    # Find the annotation JSON (sa_xxxxxx_result.json)
    json_files = [f for f in os.listdir(folder_path) if f.endswith('_result.json')]
    if not json_files:
        continue

    json_path = os.path.join(folder_path, json_files[0])

    # Load the annotation; skip the folder on any parse/IO error.
    try:
        with open(json_path, 'r') as f:
            annotation = json.load(f)
    except Exception as e:
        print(f"Error loading JSON for {train_folder}: {str(e)}")
        continue

    # Annotation schema (as used here): 'original_image' is a filename inside
    # the folder, 'extra_caption' is the long caption to segment.
    image_filename = annotation['original_image']
    image_path = os.path.join(folder_path, image_filename)
    caption = annotation['extra_caption']

    # Per-image output directory.
    output_dir = os.path.join(output_base_dir, f"{image_filename}_results")
    os.makedirs(output_dir, exist_ok=True)

    # Skip images already processed (resume support).
    if os.path.exists(os.path.join(output_dir, 'matched_dataset_test.json')):
        continue

    # Load the image and run SAM segmentation.
    try:
        original_image = Image.open(image_path)
        resized_image = resize_image_if_needed(image_path)

        # Run SAM on the (possibly) resized image to bound memory use.
        outputs = generator(resized_image, points_per_batch=128)

        # Map masks back to the original resolution when a resize happened.
        # NOTE(review): Image.fromarray on a boolean mask may raise on some
        # Pillow versions — verify mask dtype coming out of the pipeline.
        if original_image.size != resized_image.size:
            for i in range(len(outputs['masks'])):
                mask = Image.fromarray(outputs['masks'][i])
                mask = mask.resize(original_image.size, Image.NEAREST)
                outputs['masks'][i] = np.array(mask)

        image = original_image
    except Exception as e:
        print(f"Error processing {image_filename}: {str(e)}")
        continue

    # Keep only masks covering between 1% and 80% of the image, and build a
    # cropped segment image for each.  Index 0 is always the full image.
    min_area_ratio = 0.01
    max_area_ratio = 0.8
    total_area = image.size[0] * image.size[1]
    segmented_images = [image]
    segmented_image_paths = ["original_image.jpg"]

    for i, mask in enumerate(outputs['masks']):
        segment_area = np.sum(mask)
        area_ratio = segment_area / total_area
        if min_area_ratio <= area_ratio <= max_area_ratio:
            masked_image = keep_original_background(image, mask)

            # Tight bounding box of the mask; crop to it.
            y, x = np.where(mask)
            y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
            cropped_image = np.array(masked_image)[y_min:y_max+1, x_min:x_max+1]

            segmented_images.append(Image.fromarray(np.uint8(cropped_image)))
            segmented_image_paths.append(f"{image_filename.split('.')[0]}_max_{i+1}.png")

    # Split the caption into sentences.
    tokenized_captions = tokenize_caption(caption)

    # Similarity for every (segment image, sentence) pair.
    similarity_matrix = get_clip_similarity(segmented_images, tokenized_captions)

    if similarity_matrix is None:
        print(f"Skipping image {image_filename} due to CLIP processing error")
        continue

    matched_data = []
    saved_images = set()

    # Match each sentence with its highest-similarity image.
    # NOTE(review): the loop variable `caption` shadows the full caption above;
    # harmless here because the full caption is not used again afterwards.
    for j, caption in enumerate(tokenized_captions):
        i = np.argmax(similarity_matrix[:, j])
        similarity_score = float(similarity_matrix[i, j])
        matched_image_path = segmented_image_paths[i]

        matched_data.append({
            "caption": caption,
            "matched_image_path": matched_image_path,
            "similarity_score": similarity_score
        })

        saved_images.add(i)

    # Save the original image.
    original_image_path = os.path.join(output_dir, "original_image.jpg")
    image.save(original_image_path)

    # Save only the segment crops that were actually matched to a sentence.
    for i in saved_images:
        if i != 0:  # index 0 is the original image, already saved above
            output_file_path = os.path.join(output_dir, segmented_image_paths[i])
            segmented_images[i].save(output_file_path)

    # Persist the sentence-to-segment matches.
    json_output_path = os.path.join(output_dir, 'matched_dataset_test.json')
    with open(json_output_path, 'w') as f:
        json.dump(matched_data, f, indent=2)

print("모든 test set 이미지 처리가 완료되었습니다.")
|
GOAL_github/LISM/DCIsamwithclipmaxsimmaintainbackrefsentencefortrainset.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
from transformers import pipeline, CLIPProcessor, CLIPModel
|
| 6 |
+
import cv2
|
| 7 |
+
import torch
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
# SAM 모델 로드
|
| 13 |
+
generator = pipeline("mask-generation", device=1, points_per_batch=128)
|
| 14 |
+
|
| 15 |
+
# CLIP 모델 로드
|
| 16 |
+
model_name = "openai/clip-vit-large-patch14-336"
|
| 17 |
+
clip_model = CLIPModel.from_pretrained(model_name)
|
| 18 |
+
clip_processor = CLIPProcessor.from_pretrained(model_name)
|
| 19 |
+
|
| 20 |
+
def keep_original_background(image, mask):
    """Blend a segment mask with *image*, keeping the original pixels as background.

    NOTE(review): both the foreground and the background term come from the
    same source image, so ``img * m + (1 - m) * img`` reduces to the original
    image.  The function is kept as-is for interface compatibility with the
    segment-cropping loop below — confirm whether a different background
    (e.g. blurred/black) was intended.
    """
    pixels = np.array(image)
    mask3 = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    blended = pixels * mask3 + (1 - mask3) * np.array(image)
    return Image.fromarray(np.uint8(blended))
|
| 26 |
+
|
| 27 |
+
# Resize helper: very large photos are downscaled before SAM to bound memory use.
def resize_image_if_needed(image_path, max_size=(2048, 1536), trigger_size=(4000, 3000)):
    """Open ``image_path`` and return an in-memory copy, downscaled when huge.

    Args:
        image_path: path to the image file on disk.
        max_size: (width, height) bound passed to ``Image.thumbnail`` when a
            resize is triggered (aspect ratio is preserved).
        trigger_size: (width, height) thresholds; the image is resized only
            when either dimension exceeds its threshold.  Defaults keep the
            original behavior (resize images wider than 4000 px or taller
            than 3000 px), which were previously hard-coded.

    Returns:
        PIL.Image.Image: a copy detached from the underlying file handle.
    """
    with Image.open(image_path) as img:
        if img.size[0] > trigger_size[0] or img.size[1] > trigger_size[1]:
            # thumbnail() resizes in place and never upscales.
            img.thumbnail(max_size, Image.LANCZOS)
        # Single exit replaces the duplicated `return img.copy()` of the original.
        return img.copy()
|
| 34 |
+
|
| 35 |
+
# Image-caption matching with CLIP.
def get_clip_similarity(images, texts):
    """Score every image against every text with the global CLIP model.

    Args:
        images: list of PIL images.
        texts: list of caption strings.

    Returns:
        numpy array of shape (num_images, num_texts) with the CLIP
        image-to-text logits, or ``None`` when preprocessing/inference
        fails (callers skip the sample in that case).
    """
    try:
        inputs = clip_processor(text=texts, images=images, return_tensors="pt", padding=True)
        # Inference only — no_grad avoids building an autograd graph,
        # which the original version paid for and then threw away.
        with torch.no_grad():
            outputs = clip_model(**inputs)
        return outputs.logits_per_image.cpu().numpy()
    except Exception as e:
        # Deliberate best-effort: one bad sample must not kill the batch job.
        print(f"Error in CLIP processing: {str(e)}")
        return None
|
| 44 |
+
|
| 45 |
+
def tokenize_caption(text):
    """Split *text* into sentences on '.', dropping empty fragments."""
    sentences = []
    for fragment in text.split('.'):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
|
| 48 |
+
|
| 49 |
+
# Base paths for inputs (SAM/annotation folders) and outputs (matched segments).
matching_results_path = "matching_results"
output_base_dir = "segment_with_background_DCI_train_set_max_0.01"
os.makedirs(output_base_dir, exist_ok=True)

# Collect the per-image folders for the train split (prefix 'train_sa_').
train_folders = [f for f in os.listdir(matching_results_path) if f.startswith('train_sa_')]

# Process each folder independently; every step is best-effort so one bad
# sample cannot abort the whole batch run.
for train_folder in tqdm(train_folders, desc="Processing folders"):
    folder_path = os.path.join(matching_results_path, train_folder)

    # Find the annotation JSON (sa_xxxxxx_result.json)
    json_files = [f for f in os.listdir(folder_path) if f.endswith('_result.json')]
    if not json_files:
        continue

    json_path = os.path.join(folder_path, json_files[0])

    # Load the annotation; skip the folder on any parse/IO error.
    try:
        with open(json_path, 'r') as f:
            annotation = json.load(f)
    except Exception as e:
        print(f"Error loading JSON for {train_folder}: {str(e)}")
        continue

    # Annotation schema (as used here): 'original_image' is a filename inside
    # the folder, 'extra_caption' is the long caption to segment.
    image_filename = annotation['original_image']
    image_path = os.path.join(folder_path, image_filename)
    caption = annotation['extra_caption']

    # Per-image output directory.
    output_dir = os.path.join(output_base_dir, f"{image_filename}_results")
    os.makedirs(output_dir, exist_ok=True)

    # Skip images already processed (resume support).
    if os.path.exists(os.path.join(output_dir, 'matched_dataset_train.json')):
        continue

    # Load the image and run SAM segmentation.
    try:
        original_image = Image.open(image_path)
        resized_image = resize_image_if_needed(image_path)

        # Run SAM on the (possibly) resized image to bound memory use.
        outputs = generator(resized_image, points_per_batch=128)

        # Map masks back to the original resolution when a resize happened.
        # NOTE(review): Image.fromarray on a boolean mask may raise on some
        # Pillow versions — verify mask dtype coming out of the pipeline.
        if original_image.size != resized_image.size:
            for i in range(len(outputs['masks'])):
                mask = Image.fromarray(outputs['masks'][i])
                mask = mask.resize(original_image.size, Image.NEAREST)
                outputs['masks'][i] = np.array(mask)

        image = original_image
    except Exception as e:
        print(f"Error processing {image_filename}: {str(e)}")
        continue

    # Keep only masks covering between 1% and 80% of the image, and build a
    # cropped segment image for each.  Index 0 is always the full image.
    min_area_ratio = 0.01
    max_area_ratio = 0.8
    total_area = image.size[0] * image.size[1]
    segmented_images = [image]
    segmented_image_paths = ["original_image.jpg"]

    for i, mask in enumerate(outputs['masks']):
        segment_area = np.sum(mask)
        area_ratio = segment_area / total_area
        if min_area_ratio <= area_ratio <= max_area_ratio:
            masked_image = keep_original_background(image, mask)

            # Tight bounding box of the mask; crop to it.
            y, x = np.where(mask)
            y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
            cropped_image = np.array(masked_image)[y_min:y_max+1, x_min:x_max+1]

            segmented_images.append(Image.fromarray(np.uint8(cropped_image)))
            segmented_image_paths.append(f"{image_filename.split('.')[0]}_max_{i+1}.png")

    # Split the caption into sentences.
    tokenized_captions = tokenize_caption(caption)

    # Similarity for every (segment image, sentence) pair.
    similarity_matrix = get_clip_similarity(segmented_images, tokenized_captions)

    if similarity_matrix is None:
        print(f"Skipping image {image_filename} due to CLIP processing error")
        continue

    matched_data = []
    saved_images = set()

    # Match each sentence with its highest-similarity image.
    # NOTE(review): the loop variable `caption` shadows the full caption above;
    # harmless here because the full caption is not used again afterwards.
    for j, caption in enumerate(tokenized_captions):
        i = np.argmax(similarity_matrix[:, j])
        similarity_score = float(similarity_matrix[i, j])
        matched_image_path = segmented_image_paths[i]

        matched_data.append({
            "caption": caption,
            "matched_image_path": matched_image_path,
            "similarity_score": similarity_score
        })

        saved_images.add(i)

    # Save the original image.
    original_image_path = os.path.join(output_dir, "original_image.jpg")
    image.save(original_image_path)

    # Save only the segment crops that were actually matched to a sentence.
    for i in saved_images:
        if i != 0:  # index 0 is the original image, already saved above
            output_file_path = os.path.join(output_dir, segmented_image_paths[i])
            segmented_images[i].save(output_file_path)

    # Persist the sentence-to-segment matches.
    json_output_path = os.path.join(output_dir, 'matched_dataset_train.json')
    with open(json_output_path, 'w') as f:
        json.dump(matched_data, f, indent=2)

print("모든 train set 이미지 처리가 완료되었습니다.")
|
GOAL_github/datasets/DCI_segment_only_sim_max_del_org.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GOAL_github/datasets/DCI_test.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GOAL_github/datasets/DCI_test_joint_sim_max_1
ADDED
|
File without changes
|
GOAL_github/datasets/DCI_train_del_org.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GOAL_github/datasets/docci_segment_sim_bbox_del_org.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6ce1febc6eca26cb77936fdfb2f4fc6aa8c30eb52172505e22a015d06757a7bc
|
| 3 |
+
size 35738701
|
GOAL_github/datasets/docci_test.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GOAL_github/datasets/docci_test_joint_sim_max_1
ADDED
|
File without changes
|
GOAL_github/datasets/docci_train_del_org.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GOAL_github/datasets/urban_dataset_test.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GOAL_github/goal.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
| 4 |
+
import json
|
| 5 |
+
import argparse
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
import lightning as L
|
| 10 |
+
import transformers
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import shutil
|
| 13 |
+
import time
|
| 14 |
+
import numpy as np
|
| 15 |
+
from utils.func import *
|
| 16 |
+
from utils.transforms import *
|
| 17 |
+
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
| 18 |
+
import shutil
|
| 19 |
+
import math
|
| 20 |
+
import random
|
| 21 |
+
import wandb
|
| 22 |
+
|
| 23 |
+
torch.autograd.set_detect_anomaly(True) # Enable anomaly detection
|
| 24 |
+
|
| 25 |
+
def clip_loss(sim):
    """Symmetric cross-entropy (InfoNCE) loss for a square similarity matrix.

    The diagonal entries of ``sim`` are the positive image-text pairs; the
    loss averages the image-to-text and text-to-image directions.
    """
    targets = torch.arange(len(sim), dtype=torch.long, device=sim.device)
    criterion = torch.nn.CrossEntropyLoss()
    image_to_text = criterion(sim, targets)
    text_to_image = criterion(sim.t(), targets)
    return (image_to_text + text_to_image) / 2.0
|
| 28 |
+
|
| 29 |
+
def get_patch_tokens_from_bbox(patch_tokens, bbox, b, original_image_size, image_size=224, patch_size=16):
    """Mean-pool the vision-transformer patch tokens that a bounding box overlaps.

    Args:
        patch_tokens: (B, 1 + num_patches**2, D) token sequence; index 0 is
            presumably the CLS token (hence the +1 offset below) — confirm
            against the vision backbone.
        bbox: dict with 'x1'/'y1'/'x2'/'y2' tensors in original-image pixels.
        b: batch index selecting which bbox to use.
        original_image_size: (width, height) of the source image.
        image_size: model input resolution the bbox is rescaled to.
        patch_size: side length of one ViT patch.

    Returns:
        (B, D) tensor: mean of all patch tokens the (rescaled) box touches.
    """
    org_width, org_height = original_image_size

    def _scale(key, denom, upper):
        # Map one original-image coordinate onto the model grid and clamp.
        value = int(round(bbox[key][b].item() * image_size / denom))
        return max(0, min(value, upper))

    left = _scale('x1', org_width, image_size - 1)
    top = _scale('y1', org_height, image_size - 1)
    right = _scale('x2', org_width, image_size)
    bottom = _scale('y2', org_height, image_size)

    # Patch-grid range touched by the box (ceiling division on the far edge
    # so any partially covered patch is included).
    grid = image_size // patch_size
    col_start, row_start = left // patch_size, top // patch_size
    col_stop = (right + patch_size - 1) // patch_size
    row_stop = (bottom + patch_size - 1) // patch_size

    # +1 skips the leading (presumed CLS) token at sequence position 0.
    token_indices = [
        row * grid + col + 1
        for row in range(row_start, row_stop)
        for col in range(col_start, col_stop)
    ]

    # Mean-pool the selected patch tokens along the sequence axis.
    return torch.mean(patch_tokens[:, token_indices, :], dim=1)
|
| 63 |
+
|
| 64 |
+
def get_text_tokens_from_segment(text_tokens, org_text, seg_text, processor):
    """Locate *seg_text* inside *org_text* and mean-pool its token span.

    Args:
        text_tokens: (B, L, D) tensor of text tokens - all tokens of original text
        org_text: original text string
        seg_text: segment text string
        processor: CLIP processor
    Returns:
        pooled_tokens: (B, D) tensor of pooled text tokens from the relevant segment
    """
    # Text preprocessing: collapse runs of whitespace so string positions
    # line up between the two inputs.
    org_text = ' '.join(org_text.split()).strip()
    seg_text = ' '.join(seg_text.split()).strip()

    # Split org_text into sentences
    sentences = org_text.split('.')
    sentences = [s.strip() for s in sentences if s.strip()]

    # Find seg_text position (character offset; -1 when not a substring).
    seg_pos = org_text.find(seg_text)
    current_pos = 0
    sent_idx = -1

    # Prefer an exact sentence match; track its running character offset.
    # The "+ 2" accounts for the '. ' separator removed by the split above.
    for i, sent in enumerate(sentences):
        sent = sent.strip()
        if sent == seg_text:
            seg_pos = current_pos
            sent_idx = i
            break
        current_pos += len(sent) + 2

    # NOTE(review): `assert` is stripped under `python -O`; an explicit raise
    # would be safer if this precondition can fail in production.
    assert seg_pos != -1, f"Segment text not found in original text"

    # Tokenize segment text on its own to learn its token length.
    seg_tokens = processor(text=seg_text,
                           return_tensors="pt",
                           padding=False,
                           truncation=False)
    seg_token_length = len(seg_tokens.input_ids[0]) - 2  # Exclude CLS, EOS tokens

    if sent_idx != -1:
        # Calculate token index based on sentence position: tokenize the text
        # before the matched sentence and use its length as the start offset.
        text_before = '. '.join(sentences[:sent_idx]) + ('. ' if sent_idx > 0 else '')
        tokens_before = processor(text=text_before,
                                  return_tensors="pt",
                                  padding=False,
                                  truncation=False)
        start_idx = len(tokens_before.input_ids[0])
    else:
        # Calculate token index based on raw string position (substring match).
        text_before = org_text[:seg_pos]
        tokens_before = processor(text=text_before,
                                  return_tensors="pt",
                                  padding=False,
                                  truncation=False)
        start_idx = len(tokens_before.input_ids[0])

    # Adjust range considering maximum token length.
    # NOTE(review): the original "248" comment refers to the extended
    # (long-CLIP) context length — confirm against new_max_token.
    max_length = text_tokens.shape[1]  # 248
    if start_idx >= max_length:
        # Segment starts beyond the truncated context: fall back to a window
        # of segment length anchored at the end of the sequence.
        end_idx = max_length - 1
        start_idx = max(1, end_idx - seg_token_length)  # Start from after CLS token (1)
    else:
        # Normal case: clamp the window end to the last usable position.
        end_idx = min(start_idx + seg_token_length, max_length - 1)

    # Extract tokens for the computed [start_idx, end_idx) window.
    relevant_tokens = text_tokens[:, start_idx:end_idx, :]

    # Handle case when no tokens are extracted (degenerate window):
    # fall back to the first seg_token_length tokens after CLS.
    if relevant_tokens.shape[1] == 0:
        relevant_tokens = text_tokens[:, 1:min(1 + seg_token_length, max_length), :]

    # Mean-pool the selected tokens along the sequence axis.
    pooled_tokens = torch.mean(relevant_tokens, dim=1)

    return pooled_tokens
|
| 145 |
+
|
| 146 |
+
class DLoader(Dataset):
    """Dataset pairing each original image/caption with its best segment.

    Each entry of ``data_list`` must carry 'original_filename',
    'original_caption' and a 'segment' list whose items have 'filename',
    'caption', 'bbox_coordinates' and 'similarity_score'.
    """

    def __init__(self, data_list, processor, new_max_token):
        self.data_list = data_list
        self.processor = processor
        self.new_max_token = new_max_token

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, name):
        """Open *name* as RGB and return (image, (width, height))."""
        loaded = Image.open(name).convert("RGB")
        return loaded, loaded.size

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        record = self.data_list[idx]
        full_image, full_size = self._load_image(record["original_filename"])
        full_caption = record["original_caption"]

        # Pick the single segment with the highest CLIP similarity score.
        best_segment = max(record["segment"], key=lambda s: s["similarity_score"])
        segment_image = self._load_image(best_segment["filename"])[0]
        segment_caption = best_segment["caption"]
        segment_bbox = best_segment["bbox_coordinates"]

        # Preprocess both image/text pairs to the extended context length.
        full_encoded = self.processor(images=full_image, text=full_caption, return_tensors="pt",
                                      truncation=True, padding="max_length", max_length=self.new_max_token)
        segment_encoded = self.processor(images=segment_image, text=segment_caption, return_tensors="pt",
                                         truncation=True, padding="max_length", max_length=self.new_max_token)

        return (full_encoded.pixel_values[0], full_encoded.input_ids[0],
                segment_encoded.pixel_values[0], segment_encoded.input_ids[0],
                segment_bbox, full_caption, segment_caption, full_size,
                record["original_filename"], best_segment["filename"])
|
| 183 |
+
|
| 184 |
+
def main(args):
    """Set up distributed training and launch the CLIP fine-tuning loop.

    Args:
        args: parsed CLI namespace; this function reads world_size,
            output_dir, dataset, model, new_max_token, ckpt, batch_size,
            num_workers, pin_mem, init_lr and weight_decay.
            NOTE(review): `train()` also reads a module-level `args`
            (epochs, lr schedule) defined outside this view — confirm both
            refer to the same namespace.
    """
    wandb.init(project="CLIP_Training_real", config=args)

    # One process per GPU, bf16 mixed precision via Lightning Fabric.
    fabric = L.Fabric(
        accelerator="cuda",
        devices=args.world_size,
        strategy="ddp",
        precision="bf16"
    )

    fabric.launch()
    # Distinct but deterministic seed per rank.
    fabric.seed_everything(1337 + fabric.global_rank)

    # Only rank 0 touches the filesystem for output setup.
    if fabric.global_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)

    with open(args.dataset) as f:
        train_list = json.load(f)

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.CLIPModel.from_pretrained(args.model)
        # Extend CLIP's positional embeddings to the longer token limit
        # (helper from utils.func, imported via `*` above).
        longclip_pos_embeddings(model, args.new_max_token)

    # Optionally resume from a checkpoint (full state dict on CPU first).
    if args.ckpt:
        if fabric.global_rank == 0:
            print(f"Loading checkpoint from {args.ckpt}")
        checkpoint = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(checkpoint)
        if fabric.global_rank == 0:
            print("Checkpoint loaded successfully")

    print_trainable_parameters(fabric, model)

    dataset_train = DLoader(train_list, processor, args.new_max_token)

    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        shuffle=True,
    )

    # Let Fabric wrap the loader with the distributed sampler.
    train_loader = fabric.setup_dataloaders(train_loader)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    train(fabric, model, optimizer, train_loader, processor)
|
| 235 |
+
|
| 236 |
+
def train(fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader, processor) -> None:
    """Run the multi-loss CLIP fine-tuning loop.

    Per step this optimizes four objectives:
      * loss_org:   contrastive CLIP loss on original image/caption pairs
      * loss_seg:   contrastive CLIP loss on segment image/caption pairs (weight 0.5)
      * loss_patch: MSE pulling bbox-pooled patch embeddings toward seg image embeddings
      * loss_text:  MSE pulling segment-span text embeddings toward seg text embeddings

    NOTE: reads the module-level ``args`` namespace (epochs, LR schedule, batch size).
    Changes vs. original: removed dead re-tokenization/decoding locals in the text
    loop (org_tokens / seg_tokens / *_tokens_text / start_idx / end_idx were never
    used) and renamed the local ``iter`` counter to ``step`` to stop shadowing the
    builtin. Behavior is otherwise unchanged.
    """
    step = 0
    total_iter = len(train_loader) * args.epochs

    # MSE loss used by both alignment objectives.
    mse_loss = torch.nn.MSELoss()

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        epoch_loss_org = 0.0
        epoch_loss_seg = 0.0
        epoch_loss_patch = 0.0
        epoch_loss_text = 0.0

        for i, samples in enumerate(train_loader):
            # Cosine LR schedule, stepped every iteration.
            lr = (args.init_lr - args.min_lr) * 0.5 * (1.0 + math.cos(math.pi * step / total_iter)) + args.min_lr
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            org_image, org_text, seg_image, seg_text, bbox, org_caption, seg_caption, org_image_sizes, org_image_paths, seg_image_paths = samples

            # Joint forward over the concatenated org+seg batch for contrastive losses.
            outputs = model(pixel_values=torch.cat((org_image, seg_image), dim=0),
                            input_ids=torch.cat((org_text, seg_text), dim=0),
                            output_hidden_states=True)
            # Separate tower passes to get patch/token hidden states.
            # NOTE(review): this recomputes both towers; outputs.vision_model_output /
            # outputs.text_model_output may already carry these hidden states — confirm
            # and reuse to roughly halve the forward cost.
            vision_outputs = model.vision_model(torch.cat((org_image, seg_image), dim=0), output_hidden_states=True)
            text_outputs = model.text_model(torch.cat((org_text, seg_text), dim=0), output_hidden_states=True)

            # Split embeddings into org and seg halves.
            batch_size = org_image.shape[0]
            org_image_embeds, seg_image_embeds = outputs.image_embeds[:batch_size], outputs.image_embeds[batch_size:]
            org_text_embeds, seg_text_embeds = outputs.text_embeds[:batch_size], outputs.text_embeds[batch_size:]

            # Patch/text tokens from the last hidden states (org half only).
            org_patch_tokens = vision_outputs.hidden_states[-1][:batch_size]  # (B, N, D)
            org_text_tokens = text_outputs.hidden_states[-1][:batch_size]  # (B, L, D)

            # Original CLIP contrastive losses.
            eps = 1e-8
            x_i = batch_align(fabric, F.normalize(outputs.image_embeds + eps))
            x_t = batch_align(fabric, F.normalize(outputs.text_embeds + eps))
            x_i_org, x_i_seg = x_i.chunk(2)
            x_t_org, x_t_seg = x_t.chunk(2)

            sim_org = model.logit_scale.exp() * x_i_org @ x_t_org.t()
            loss_org = clip_loss(sim_org)
            sim_seg = model.logit_scale.exp() * x_i_seg @ x_t_seg.t()
            loss_seg = clip_loss(sim_seg)

            # Patch-level alignment: pool org patch tokens inside each sample's bbox.
            patch_pooled = []
            for b in range(batch_size):
                # org_image_sizes is collated as [width_tensor, height_tensor].
                img_width = org_image_sizes[0][b].item()
                img_height = org_image_sizes[1][b].item()
                img_size = (img_width, img_height)

                pooled = get_patch_tokens_from_bbox(org_patch_tokens[b:b+1],
                                                    bbox,
                                                    b,
                                                    img_size,
                                                    image_size=args.image_size,
                                                    patch_size=16)
                patch_pooled.append(pooled)

            patch_pooled = torch.cat(patch_pooled, dim=0)
            patch_pooled = model.vision_model.post_layernorm(patch_pooled)
            patch_pooled = model.visual_projection(patch_pooled)
            patch_pooled = F.normalize(patch_pooled + eps, dim=-1)
            seg_image_embeds = F.normalize(seg_image_embeds + eps, dim=-1)

            # Plain cosine similarity (deliberately no logit_scale); push diagonal to 1.
            sim_patch = patch_pooled @ seg_image_embeds.t()
            patch_diag = torch.diag(sim_patch)
            loss_patch = mse_loss(patch_diag, torch.ones_like(patch_diag))

            # Text-level alignment: pool org text tokens over the segment-caption span.
            text_pooled = []
            for b in range(batch_size):
                pooled = get_text_tokens_from_segment(org_text_tokens[b:b+1],
                                                      org_caption[b],
                                                      seg_caption[b],
                                                      processor)
                text_pooled.append(pooled)
            text_pooled = torch.cat(text_pooled, dim=0)

            text_pooled = model.text_model.final_layer_norm(text_pooled)
            text_pooled = model.text_projection(text_pooled)
            text_pooled = F.normalize(text_pooled + eps, dim=-1)

            seg_text_embeds = F.normalize(seg_text_embeds + eps, dim=-1)

            # Plain cosine similarity (no logit_scale); push diagonal to 1.
            sim_text = text_pooled @ seg_text_embeds.t()
            text_diag = torch.diag(sim_text)
            loss_text = mse_loss(text_diag, torch.ones_like(text_diag))

            # Total loss; seg contrastive term is down-weighted.
            loss = loss_org + 0.5 * loss_seg + loss_patch + loss_text

            epoch_loss += loss.item()
            epoch_loss_org += loss_org.item()
            epoch_loss_seg += loss_seg.item()
            epoch_loss_patch += loss_patch.item()
            epoch_loss_text += loss_text.item()

            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            if fabric.global_rank == 0:
                wandb.log({
                    "iter": step,
                    "lr": lr,
                    "loss": loss.item(),
                    "loss_org": loss_org.item(),
                    "loss_seg": loss_seg.item(),
                    "loss_patch": loss_patch.item(),
                    "loss_text": loss_text.item(),
                    "epoch": epoch,
                    "progress": (step / total_iter) * 100,
                    "batch_size": args.batch_size,
                    "logit_scale": model.logit_scale.exp().item(),
                    "patch_similarity": patch_diag.mean().item(),  # average patch similarity
                    "text_similarity": text_diag.mean().item(),  # average text similarity
                })

            fabric.print(f"epoch {epoch} iter {step} ({(step/total_iter)*100:.4f}%) lr {lr:.6f} "
                         f"loss {loss.item():.4f} (org: {loss_org.item():.4f}, seg: {loss_seg.item():.4f}, "
                         f"patch: {loss_patch.item():.4f}, text: {loss_text.item():.4f} "
                         f"patch_sim: {patch_diag.mean().item():.4f}, text_sim: {text_diag.mean().item():.4f})")
            step += 1

        # Calculate and log epoch averages.
        avg_epoch_loss = epoch_loss / len(train_loader)
        avg_epoch_loss_org = epoch_loss_org / len(train_loader)
        avg_epoch_loss_seg = epoch_loss_seg / len(train_loader)
        avg_epoch_loss_patch = epoch_loss_patch / len(train_loader)
        avg_epoch_loss_text = epoch_loss_text / len(train_loader)

        if fabric.global_rank == 0:
            wandb.log({
                "epoch": epoch,
                "avg_epoch_loss": avg_epoch_loss,
                "avg_epoch_loss_org": avg_epoch_loss_org,
                "avg_epoch_loss_seg": avg_epoch_loss_seg,
                "avg_epoch_loss_patch": avg_epoch_loss_patch,
                "avg_epoch_loss_text": avg_epoch_loss_text,
            })

        # Save model weights once per epoch (rank 0 only), CPU-side to keep GPU memory free.
        save_path = os.path.join(args.output_dir,
                                 f"GOAL_12_{os.path.splitext(os.path.basename(args.model))[0]}_"
                                 f"{os.path.splitext(os.path.basename(args.dataset))[0]}_{epoch+1}_{args.image_size}.pth")

        fabric.barrier()
        if fabric.global_rank == 0:
            model_state_dict = model.state_dict()
            cpu_state_dict = {k: v.cpu() for k, v in model_state_dict.items()}
            torch.save(cpu_state_dict, save_path)
            fabric.print(f"Model saved to {save_path}")
        fabric.barrier()
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def get_args_parser():
    """Build the argparse parser for CLIP training.

    Fix vs. original: the duplicated ``parser.set_defaults(pin_mem=True)`` call
    was removed (it appeared twice back to back).

    Returns:
        argparse.ArgumentParser: configured parser (call ``.parse_args()`` on it).
    """
    parser = argparse.ArgumentParser('CLIP Training', add_help=False)
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU')
    parser.add_argument('--epochs', default=10, type=int)
    parser.add_argument('--image_size', default=224, type=int)
    parser.add_argument('--new_max_token', default=248, type=int)
    parser.add_argument('--dataset', default='datasets/docci_segment_sim_bbox_del_org.json', type=str)
    parser.add_argument('--model', default='openai/clip-vit-base-patch16', type=str)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--init_lr', type=float, default=5e-6, metavar='LR')  # originally 5e-6
    parser.add_argument('--min_lr', type=float, default=0, metavar='LR')
    parser.add_argument('--output_dir', default='finetune_out_SA_1B_100k_plus_docci/goal_bbox_local_token_align_batch16_only_max_pair_base16_patch16_real',
                        help='path where to save, empty for no saving')
    parser.add_argument('--save_interval', default=1, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.add_argument('--wandb_project', type=str, default='CLIP_Training', help='wandb project name')
    parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint file')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')

    return parser
|
| 463 |
+
|
| 464 |
+
if __name__ == "__main__":
    # Parse CLI args, create the output directory, and launch training.
    args = get_args_parser()
    args = args.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
|
GOAL_github/mAP_goal_jointtest.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import lightning as L
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import transformers
|
| 11 |
+
from torch.utils.data import DataLoader, Dataset, Subset
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from utils.func import *
|
| 15 |
+
from utils.transforms import *
|
| 16 |
+
|
| 17 |
+
# Hyperparameters
micro_batch_size = 32  # per-forward-pass batch size during evaluation
devices = 1  # number of GPUs handed to Fabric
num_workers = 1  # DataLoader worker processes
|
| 21 |
+
|
| 22 |
+
class QueryLoader(Dataset):
    """Retrieval dataset pairing each image file with its caption.

    Besides the per-item samples, builds three lookup structures:
      * image_to_segs: org filename -> list of its segment entries
      * org_images: set of original (non-segment) filenames
      * filename_to_caption: filename -> caption
    """

    def __init__(self, data_list, processor):
        self.data_list = data_list
        self.processor = processor
        mappings = self._create_mappings()
        self.image_to_segs, self.org_images, self.filename_to_caption = mappings

    def _create_mappings(self):
        """Index the data list into seg buckets, the org-image set, and a caption map."""
        seg_map = {}
        originals = set()
        captions = {}
        for entry in self.data_list:
            name = entry['filename']
            captions[name] = entry['caption']
            if 'segment_with_background' in name:
                # Attach this segment entry to its original image's bucket.
                seg_map.setdefault(get_org_filename(name), []).append(entry)
            else:
                originals.add(name)
                # NOTE(review): this resets the bucket if segments were listed before
                # their org image — confirm the data ordering guarantees this is safe.
                seg_map[name] = []
        return seg_map, originals, captions

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, id: int):
        return Image.open(self.data_list[id]["filename"]).convert("RGB")

    def _load_target(self, id: int):
        return self.data_list[id]["caption"]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self._load_image(idx)
        caption = self._load_target(idx)
        # NOTE: relies on the module-level `args` for the max token length.
        data = self.processor(images=image, text=caption, return_tensors="pt", truncation=True, padding="max_length", max_length=args.new_max_token)
        return data.pixel_values[0], data.input_ids[0], self.data_list[idx]["filename"]
|
| 61 |
+
|
| 62 |
+
def process_chunk(fabric, model, data_loader):
    """Encode every (image, text, filename) batch of `data_loader`.

    Returns:
        (image_embeds, text_embeds, filenames): L2-normalized embedding tensors
        stacked over all batches, plus the flat list of filenames in order.
    """
    image_chunks = []
    text_chunks = []
    names = []
    for image, text, filename in data_loader:
        with torch.no_grad():
            out = model(pixel_values=image.to(fabric.device), input_ids=text.to(fabric.device))
            image_chunks.append(F.normalize(out.image_embeds))
            text_chunks.append(F.normalize(out.text_embeds))
        names.extend(filename)
    return torch.cat(image_chunks), torch.cat(text_chunks), names
|
| 76 |
+
|
| 77 |
+
def compute_similarity(images, texts):
    """Pairwise similarity matrix: row i, column j = images[i] . texts[j]."""
    return images @ texts.t()
|
| 79 |
+
|
| 80 |
+
def compute_ap(ranks, relevant_items, k=None):
    """Average precision of a ranked list against a set of relevant indices.

    Args:
        ranks: gallery indices sorted by decreasing similarity.
        relevant_items: indices considered relevant for this query.
        k: optional cutoff; only the top-k ranks are scored.

    Returns:
        AP normalized by the total number of relevant items; 0.0 when there
        are no relevant items.

    Fix vs. original: membership was tested against the list on every rank
    (O(n) per test); a set is built once for O(1) lookups. Scores unchanged.
    """
    if not relevant_items:
        return 0.0
    relevant = set(relevant_items)  # O(1) membership instead of O(n) list scans
    score = 0.0
    num_hits = 0.0
    for i, item in enumerate(ranks[:k] if k else ranks):
        if item in relevant:
            num_hits += 1.0
            score += num_hits / (i + 1.0)
    # Keep the original normalization (length of the given list, not the set).
    return score / len(relevant_items)
|
| 90 |
+
|
| 91 |
+
def get_org_filename(filename):
    """Map a segment-image path back to its original image's filename.

    Segment files live under a directory named "<org>_results"; that directory
    name yields the original filename (with ".jpg" appended if missing).
    Non-segment filenames are returned unchanged.
    """
    if 'segment_with_background' not in filename:
        return filename
    parent_dir = filename.split('/')[-2]
    base = parent_dir.split('_results')[0]
    return base if base.endswith('.jpg') else base + '.jpg'
|
| 99 |
+
|
| 100 |
+
def get_relevant_items(filename, all_filenames, image_to_segs, org_images, filename_to_caption):
    """Indices in `all_filenames` considered relevant to the query image.

    Org-image query: all of its segment crops are relevant.
    Seg-image query: the parent org image and the seg image itself are relevant.
    (`filename_to_caption` is accepted for interface symmetry but unused here.)
    """
    org_filename = get_org_filename(filename)

    if filename in org_images:
        # Org-image query.
        return [all_filenames.index(seg['filename']) for seg in image_to_segs[filename]]

    # Seg-image query.
    hits = []
    if org_filename in all_filenames:
        hits.append(all_filenames.index(org_filename))
    if filename in all_filenames:
        hits.append(all_filenames.index(filename))
    return hits
|
| 115 |
+
|
| 116 |
+
def get_relevant_captions(filename, image_to_segs, org_images, filename_to_caption):
    """Captions considered relevant to the query image.

    Org-image query: its own caption plus the captions of all its segments.
    Seg-image query: its own caption plus the parent org image's caption.
    """
    org_filename = get_org_filename(filename)

    captions = [filename_to_caption[filename]]  # add the query image's own caption
    if filename in org_images:
        # Org-image query: include every segment's caption.
        captions.extend(seg['caption'] for seg in image_to_segs[filename])
    elif org_filename in filename_to_caption:
        # Seg-image query: include the parent org image's caption.
        captions.append(filename_to_caption[org_filename])
    return captions
|
| 130 |
+
|
| 131 |
+
def get_relevant_items_for_text(query_caption, all_filenames, image_to_segs, org_images, filename_to_caption):
    """Deduplicated gallery indices of all images whose caption matches `query_caption`.

    A matching org image pulls in all of its segments; a matching seg image
    pulls in its parent org image (when present in the gallery).
    """
    hits = []
    for fname, caption in filename_to_caption.items():
        if caption != query_caption:
            continue
        if fname in org_images:
            # Query caption belongs to an org image: add it and all of its segments.
            hits.append(all_filenames.index(fname))
            hits.extend(all_filenames.index(seg['filename']) for seg in image_to_segs[fname])
        else:
            # Query caption belongs to a seg image: add it and its parent org image.
            hits.append(all_filenames.index(fname))
            parent = get_org_filename(fname)
            if parent in all_filenames:
                hits.append(all_filenames.index(parent))
    return list(set(hits))  # remove duplicates
|
| 146 |
+
|
| 147 |
+
@torch.no_grad()
def test(fabric: L.Fabric, model: torch.nn.Module, query_loader, k=None) -> torch.Tensor:
    """Compute joint-set mAP@k for image->text and text->image retrieval.

    Embeds the whole dataset in chunks (to bound GPU memory), builds the full
    image-text similarity matrix, then scores each query with compute_ap.
    """
    fabric.print("Testing ...")

    chunk_size = 5000
    dataset_size = len(query_loader.dataset)
    all_images = []
    all_texts = []
    all_filenames = []

    # Embed the dataset chunk by chunk to keep peak GPU memory bounded.
    for start_idx in range(0, dataset_size, chunk_size):
        end_idx = min(start_idx + chunk_size, dataset_size)
        chunk_dataset = Subset(query_loader.dataset, range(start_idx, end_idx))
        chunk_loader = DataLoader(chunk_dataset, batch_size=query_loader.batch_size, shuffle=False, num_workers=query_loader.num_workers)

        chunk_images, chunk_texts, chunk_filenames = process_chunk(fabric, model, chunk_loader)

        all_images.append(chunk_images)
        all_texts.append(chunk_texts)
        all_filenames.extend(chunk_filenames)

        torch.cuda.empty_cache()

    all_images = torch.cat(all_images)
    all_texts = torch.cat(all_texts)

    # Full gallery similarity: rows = images, columns = texts.
    similarity = compute_similarity(all_images, all_texts)

    image_to_segs = query_loader.dataset.image_to_segs
    org_images = query_loader.dataset.org_images
    filename_to_caption = query_loader.dataset.filename_to_caption
    mAP_i2t = 0.0
    mAP_t2i = 0.0

    # Image to Text
    for i, filename in enumerate(all_filenames):
        relevant_captions = get_relevant_captions(filename, image_to_segs, org_images, filename_to_caption)

        i2t_ranks = torch.argsort(similarity[i], descending=True).tolist()
        # A text is relevant when its caption is in the relevant-caption set.
        i2t_relevant = [idx for idx, fn in enumerate(all_filenames) if filename_to_caption[fn] in relevant_captions]
        mAP_i2t += compute_ap(i2t_ranks, i2t_relevant, k)

    # Text to Image
    unique_captions = set(filename_to_caption.values())
    for caption in unique_captions:
        relevant_items = get_relevant_items_for_text(caption, all_filenames, image_to_segs, org_images, filename_to_caption)

        # Represent the caption by the first filename that carries it.
        # NOTE(review): if several files share this caption, only one text column
        # is queried — confirm that is intended.
        caption_index = all_filenames.index(next(filename for filename, cap in filename_to_caption.items() if cap == caption))
        t2i_ranks = torch.argsort(similarity[:, caption_index], descending=True).tolist()
        mAP_t2i += compute_ap(t2i_ranks, relevant_items, k)

    mAP_i2t /= len(all_filenames)
    mAP_t2i /= len(unique_captions)

    fabric.print(f"mAP@{k if k else 'all'} - Text-to-Image: {mAP_t2i:.4f} & Image-to-Text: {mAP_i2t:.4f}")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def main(args):
    """Set up Fabric, load the (optionally fine-tuned) CLIP model, and run mAP evaluation."""
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        precision="bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Expand the short model alias into the full hub identifier.
    if args.model == 'L-336':
        args.model = 'openai/clip-vit-large-patch14-336'
    elif args.model == 'L':
        args.model = 'openai/clip-vit-large-patch14'
    elif args.model == 'B':
        args.model = 'openai/clip-vit-base-patch16'
    elif args.model == 'G':
        args.model = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.AutoModel.from_pretrained(args.model).bfloat16()
        # Extend text positional embeddings before loading the fine-tuned weights.
        longclip_pos_embeddings(model, args.new_max_token)
        # NOTE(review): strict=False silently ignores missing/unexpected keys — confirm intended.
        model.load_state_dict(torch.load(args.ckpt), strict=False)

    # NOTE(review): no else branch — an unknown --dataset leaves query_list unbound
    # (NameError below). Also the ':1' in the filename differs from the shipped
    # 'docci_test_joint_sim_max_1' dataset file — confirm the path exists.
    if args.dataset == 'docci':
        query_list = 'datasets/docci_test_joint_sim_max_1:1.json'
    elif args.dataset == 'DCI':
        query_list = 'datasets/DCI_test_joint_sim_max_1:1.json'

    with open(query_list) as f:
        query_list = json.loads(f.read())

    args.query_list = query_list

    query_dataset = QueryLoader(query_list, processor)
    query_loader = DataLoader(query_dataset, num_workers=num_workers, batch_size=micro_batch_size, shuffle=False, drop_last=False, pin_memory=False)
    query_loader = fabric.setup_dataloaders(query_loader)

    model.eval().to(fabric.device)
    test(fabric, model, query_loader, args.k)
|
| 244 |
+
|
| 245 |
+
if __name__ == "__main__":
    # Allow TF32 matmul for speed on Ampere+ GPUs.
    torch.set_float32_matmul_precision("high")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='docci')
    parser.add_argument('--new_max_token', default=248, type=int)
    parser.add_argument("--model", type=str, default='B')
    parser.add_argument("--ckpt", type=str, default='')
    parser.add_argument("--k", type=int, default=10, help="Limit rank calculation to top K results. Use None for all ranks.")
    args = parser.parse_args()

    main(args)
|
GOAL_github/retrieval_goal.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import lightning as L
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import transformers
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from utils.func import *
|
| 16 |
+
from utils.transforms import *
|
| 17 |
+
|
| 18 |
+
# Hyperparameters
micro_batch_size = 32  # per-forward-pass batch size during evaluation
devices = 1  # number of GPUs handed to Fabric
num_workers = 1  # DataLoader worker processes
|
| 22 |
+
|
| 23 |
+
class QueryLoader(Dataset):
    """Dataset yielding (pixel_values, input_ids) pairs for retrieval evaluation."""

    def __init__(self, data_list, processor):
        self.data_list = data_list
        self.processor = processor

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, id: int):
        return Image.open(self.data_list[id]["filename"]).convert("RGB")

    def _load_target(self, id: int):
        return self.data_list[id]["caption"]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self._load_image(idx)
        caption = self._load_target(idx)
        # NOTE: relies on the module-level `args` for the max token length.
        encoded = self.processor(images=image, text=caption, return_tensors="pt", truncation=True, padding="max_length", max_length=args.new_max_token)

        return encoded.pixel_values[0], encoded.input_ids[0]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main(args):
    """Set up Fabric, load the CLIP model (optionally with a checkpoint), and run retrieval test."""
    # Use only GPU index 0.
    # NOTE(review): setting CUDA_VISIBLE_DEVICES after torch is imported usually
    # has no effect on device visibility — confirm; the module already sets it at import time.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        precision="bf16-mixed"  # changed from "32" to "bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Expand the short model alias into the full hub identifier.
    if args.model=='L-336':
        args.model = 'openai/clip-vit-large-patch14-336'
    elif args.model=='L':
        args.model = 'openai/clip-vit-large-patch14'
    elif args.model=='B':
        args.model = 'openai/clip-vit-base-patch16'
    elif args.model=='G':
        args.model = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.AutoModel.from_pretrained(args.model).bfloat16()
        # Extend text positional embeddings before loading fine-tuned weights.
        longclip_pos_embeddings(model, args.new_max_token)
        if args.ckpt:  # load only when a checkpoint path was provided
            model.load_state_dict(torch.load(args.ckpt), strict=False)

    # Map dataset alias to its test-split JSON.
    # NOTE(review): an unknown --dataset leaves query_list unbound (NameError below).
    if args.dataset == 'docci':
        query_list = 'datasets/docci_test.json'
    elif args.dataset =='coco':
        query_list = 'datasets/coco_test.json'
    elif args.dataset =='flickr30k':
        query_list = 'datasets/flickr30k_test.json'
    elif args.dataset =='DCI':
        query_list = 'datasets/DCI_test.json'
    elif args.dataset =='urban':
        query_list = 'datasets/urban_dataset_test.json'
    elif args.dataset =='sharegpt4v':
        query_list = 'datasets/sharegpt4v_test.json'

    with open(query_list) as f:
        query_list = json.loads(f.read())

    args.query_list = query_list

    query_dataset = QueryLoader(query_list, processor)
    query_loader = DataLoader(query_dataset, num_workers=num_workers, batch_size=micro_batch_size, shuffle=False, drop_last=False, pin_memory=False)
    query_loader = fabric.setup_dataloaders(query_loader)

    model.eval().to(fabric.device)
    test(fabric, model, query_loader)
|
| 97 |
+
|
| 98 |
+
def compute_AP_and_recall_at_Ks(similarity, label_matrix, Ks):
    """Unnormalized-precision AP and recall at several cutoffs for one query.

    Args:
        similarity: 1-D tensor of query-to-gallery scores.
        label_matrix: 1-D boolean tensor marking relevant gallery entries.
        Ks: iterable of cutoff ranks.

    Returns:
        dict K -> {'AP', 'recall', 'relevant_items'}; AP is normalized by the
        total number of relevant items in the gallery.
    """
    # Gallery indices ordered by decreasing similarity.
    order = torch.argsort(similarity, descending=True)
    results = {K: {'AP': 0.0, 'recall': 0.0, 'relevant_items': 0} for K in Ks}
    total_relevant = label_matrix.sum().item()
    for rank, gallery_idx in enumerate(order):
        if not label_matrix[gallery_idx]:
            continue
        for K in Ks:
            if rank < K:
                stats = results[K]
                stats['relevant_items'] += 1
                stats['AP'] += stats['relevant_items'] / (rank + 1)
                stats['recall'] = stats['relevant_items'] / total_relevant
    for K in Ks:
        results[K]['AP'] /= total_relevant
    return results
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@torch.no_grad()
def test(fabric: L.Fabric, model: torch.nn.Module, query_loader) -> torch.Tensor:
    """Encode the query set and report bidirectional retrieval recall@{1,5,25,50}."""
    fabric.print("Testing ...")

    img_embeds = torch.tensor([], dtype=torch.float32).to(fabric.device)
    txt_embeds = torch.tensor([], dtype=torch.float32).to(fabric.device)

    # Accumulate L2-normalized image/text embeddings over the loader.
    for batch in query_loader:
        pixel_values, input_ids = batch
        out = model(pixel_values=pixel_values, input_ids=input_ids)
        img_embeds = torch.cat((img_embeds, F.normalize(out.image_embeds)), dim=0)
        txt_embeds = torch.cat((txt_embeds, F.normalize(out.text_embeds)), dim=0)

    # Embeddings are normalized, so a dot product is cosine similarity.
    similarity = torch.mm(img_embeds, txt_embeds.t())
    gt = torch.arange(img_embeds.shape[0]).to(fabric.device)

    # Rank of the ground-truth match for each image (I2T) and each text (T2I).
    order_i2t = torch.argsort(similarity, descending=True)
    ranks_i2t = (order_i2t == gt[:, None]).nonzero(as_tuple=True)[1]
    order_t2i = torch.argsort(similarity.t(), descending=True)
    ranks_t2i = (order_t2i == gt[:, None]).nonzero(as_tuple=True)[1]

    def recall_at(ranks, k):
        # Percentage of queries whose match appears in the top-k results.
        return (ranks < k).float().mean().item() * 100

    t2i = [recall_at(ranks_t2i, k) for k in (1, 5, 25, 50)]
    i2t = [recall_at(ranks_i2t, k) for k in (1, 5, 25, 50)]

    fabric.print("Text-to-Image: " + " & ".join(f"{r:.2f}" for r in t2i))
    fabric.print("Image-to-Text: " + " & ".join(f"{r:.2f}" for r in i2t))
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    # CLI options; defaults mirror the training scripts.
    parser = argparse.ArgumentParser()
    for flag, kwargs in (
        ("--dataset", dict(type=str, default='urban')),
        ("--new_max_token", dict(type=int, default=248)),
        ("--model", dict(type=str, default='B')),
        ("--ckpt", dict(type=str, default='')),
    ):
        parser.add_argument(flag, **kwargs)
    args = parser.parse_args()

    main(args)
|
GOAL_github/utils/__pycache__/easydict.cpython-39.pyc
ADDED
|
Binary file (3.64 kB). View file
|
|
|
GOAL_github/utils/__pycache__/func.cpython-310.pyc
ADDED
|
Binary file (3.92 kB). View file
|
|
|
GOAL_github/utils/__pycache__/func.cpython-311.pyc
ADDED
|
Binary file (9.11 kB). View file
|
|
|
GOAL_github/utils/__pycache__/func.cpython-39.pyc
ADDED
|
Binary file (3.86 kB). View file
|
|
|
GOAL_github/utils/__pycache__/randaugment.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
GOAL_github/utils/__pycache__/randaugment.cpython-311.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
GOAL_github/utils/__pycache__/randaugment.cpython-39.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
GOAL_github/utils/__pycache__/transforms.cpython-310.pyc
ADDED
|
Binary file (3.55 kB). View file
|
|
|
GOAL_github/utils/__pycache__/transforms.cpython-311.pyc
ADDED
|
Binary file (5.44 kB). View file
|
|
|
GOAL_github/utils/__pycache__/transforms.cpython-39.pyc
ADDED
|
Binary file (3.61 kB). View file
|
|
|
GOAL_github/utils/func.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import cv2
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def interpolate_pos_embeddings(model, new_image_size):
    """Resize the vision tower's positional-embedding table to a new resolution.

    Nearest-neighbor interpolates the (CLS + patches) position table so a
    CLIP-style vision encoder accepts `new_image_size` x `new_image_size` inputs.
    """
    vision_model = model.vision_model
    patch_size = vision_model.config.patch_size
    # +1 accounts for the CLS token position.
    num_patches = (new_image_size // patch_size) ** 2 + 1
    pos_embeddings = vision_model.embeddings.position_embedding.weight
    pos_embeddings = pos_embeddings.unsqueeze(0).permute(0, 2, 1)  # 1 x C x N
    pos_embeddings = torch.nn.functional.interpolate(
        pos_embeddings, size=(num_patches), mode='nearest'
    ).squeeze(0).permute(1, 0)  # back to N x C
    pos_embeddings = pos_embeddings.contiguous()
    vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(pos_embeddings)
    # Keep position_ids in sync with the new table length. Fix: the fallback
    # previously registered the buffer on vision_model itself instead of its
    # embeddings submodule, where `position_ids` is actually looked up.
    new_ids = torch.arange(0, num_patches).unsqueeze(0)
    if hasattr(vision_model.embeddings, 'position_ids'):
        vision_model.embeddings.position_ids = new_ids
    else:
        vision_model.embeddings.register_buffer('position_ids', new_ids)
|
| 23 |
+
|
| 24 |
+
def interpolate_text_pos_embeddings(model, new_max_token):
    """Stretch the text encoder's positional embeddings to `new_max_token` slots
    via nearest-neighbor interpolation."""
    text_model = model.text_model
    pos_embeddings = text_model.embeddings.position_embedding.weight
    pos_embeddings = pos_embeddings.unsqueeze(0).permute(0, 2, 1)  # 1 x C x N
    pos_embeddings = torch.nn.functional.interpolate(
        pos_embeddings, size=(new_max_token), mode='nearest'
    ).squeeze(0).permute(1, 0)  # back to N x C
    pos_embeddings = pos_embeddings.contiguous()
    text_model.embeddings.position_embedding.weight = torch.nn.Parameter(pos_embeddings)
    # Fix: register the fallback buffer on the embeddings submodule (the
    # original attached it to text_model, where nothing reads it).
    new_ids = torch.arange(0, new_max_token).unsqueeze(0)
    if hasattr(text_model.embeddings, 'position_ids'):
        text_model.embeddings.position_ids = new_ids
    else:
        text_model.embeddings.register_buffer('position_ids', new_ids)
|
| 43 |
+
|
| 44 |
+
def longclip_pos_embeddings(model, new_max_token):
    """Long-CLIP-style positional-embedding stretching for the text encoder.

    Keeps the first `keep_len` positions untouched, linearly interpolates the
    rest at 4x density (with a short linear extrapolation past the final
    original position), then truncates the table to `new_max_token` entries.
    """
    text_model = model.text_model
    pos_embeddings_pre = text_model.embeddings.position_embedding.weight
    length, dim = pos_embeddings_pre.shape
    keep_len = 20
    new_length = 4 * length - 3 * keep_len
    if new_length < new_max_token:
        raise ValueError("new_max_token is too large")
    # Fix: build the full stretched table first. The original allocated only
    # new_max_token rows yet wrote indices up to new_length - 1, raising an
    # IndexError whenever new_max_token < new_length (which the guard allows).
    full = torch.zeros([new_length, dim], dtype=pos_embeddings_pre.dtype)
    for i in range(keep_len):
        full[i] = pos_embeddings_pre[i]
    for i in range(length - 1 - keep_len):
        base = pos_embeddings_pre[i + keep_len]
        nxt = pos_embeddings_pre[i + 1 + keep_len]
        full[4 * i + keep_len] = base
        full[4 * i + 1 + keep_len] = 3 * base / 4 + 1 * nxt / 4
        full[4 * i + 2 + keep_len] = 2 * base / 4 + 2 * nxt / 4
        full[4 * i + 3 + keep_len] = 1 * base / 4 + 3 * nxt / 4
    # Extrapolate the final four slots beyond the last original position.
    step = (pos_embeddings_pre[length - 1] - pos_embeddings_pre[length - 2]) / 4
    for j in range(4):
        full[new_length - 4 + j] = pos_embeddings_pre[length - 1] + j * step
    pos_embeddings_new = full[:new_max_token].contiguous()
    text_model.embeddings.position_embedding.weight = torch.nn.Parameter(pos_embeddings_new)
    # Fix: register the fallback buffer on the embeddings submodule, where
    # `position_ids` is actually looked up.
    new_ids = torch.arange(0, new_max_token).unsqueeze(0)
    if hasattr(text_model.embeddings, 'position_ids'):
        text_model.embeddings.position_ids = new_ids
    else:
        text_model.embeddings.register_buffer('position_ids', new_ids)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool hidden states over the attended (unmasked) token positions."""
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    counts = attention_mask.sum(dim=1)[..., None]
    return summed / counts
|
| 76 |
+
|
| 77 |
+
def last_token_pool(last_hidden_states, attention_mask):
    """Select each sequence's final attended token's hidden state.

    Left-padded batches (last column fully attended) take the last column;
    right-padded batches gather at each row's true sequence end.
    """
    batch_size = last_hidden_states.shape[0]
    uses_left_padding = attention_mask[:, -1].sum() == batch_size
    if uses_left_padding:
        return last_hidden_states[:, -1]
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(batch_size, device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]
|
| 85 |
+
|
| 86 |
+
def batch_align(fabric, x):
    """Gather `x` from every process and fold the world dim into the batch dim."""
    gathered = fabric.all_gather(x, sync_grads=True)
    world, per_rank = gathered.shape[0], gathered.shape[1]
    return gathered.view(world * per_rank, -1)
|
| 89 |
+
|
| 90 |
+
# Shared cross-entropy criterion for the symmetric contrastive objective.
cls_criterion = torch.nn.CrossEntropyLoss()

def clip_loss(logits):
    """Symmetric InfoNCE loss over an image-text similarity matrix.

    The i-th image matches the i-th text, so the diagonal holds the positives.
    """
    targets = torch.arange(len(logits), dtype=torch.long, device=logits.device)
    image_to_text = cls_criterion(logits, targets)
    text_to_image = cls_criterion(logits.t(), targets)
    return (image_to_text + text_to_image) / 2.0
|
| 95 |
+
|
| 96 |
+
def print_trainable_parameters(fabric, model):
    """Log trainable vs. total parameter counts and allocated CUDA memory."""
    totals = {'trainable': 0, 'all': 0}
    for _, param in model.named_parameters():
        count = param.numel()
        totals['all'] += count
        if param.requires_grad:
            totals['trainable'] += count
    fabric.print(
        f"trainable params: {totals['trainable']} || all params: {totals['all']} || trainable%: {100 * totals['trainable'] / totals['all']:.2f}"
    )
    fabric.print('Memory load of model: {} bytes'.format(torch.cuda.memory_allocated()))
|
GOAL_github/utils/randaugment.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code from salesforce/BLIP
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
## aug functions
|
| 8 |
+
def identity_func(img):
    """No-op augmentation: return the image unchanged."""
    return img
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def autocontrast_func(img, cutoff=0):
    """Per-channel autocontrast, same output as PIL.ImageOps.autocontrast."""
    n_bins = 256

    def remap_channel(ch):
        cut = cutoff * ch.size // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            # Discard `cut` pixels from each tail of the histogram.
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            lut = np.arange(n_bins)
        else:
            # Linear remap stretching [low, high] onto [0, 255].
            scale = (n_bins - 1) / (high - low)
            lut = np.arange(n_bins) * scale + np.multiply(low, -scale)
            lut[lut < 0] = 0
            lut[lut > n_bins - 1] = n_bins - 1
        lut = lut.clip(0, 255).astype(np.uint8)
        return lut[ch]

    return cv2.merge([remap_channel(ch) for ch in cv2.split(img)])
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def equalize_func(img):
    """Per-channel histogram equalization matching PIL.ImageOps.equalize
    (PIL's algorithm differs from cv2.equalizeHist)."""
    n_bins = 256

    def remap_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        nonzero = hist[hist != 0].reshape(-1)
        step = np.sum(nonzero[:-1]) // (n_bins - 1)
        if step == 0:
            # Degenerate histogram: leave the channel untouched.
            return ch
        shifted = np.empty_like(hist)
        shifted[0] = step // 2
        shifted[1:] = hist[:-1]
        lut = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lut[ch]

    return cv2.merge([remap_channel(ch) for ch in cv2.split(img)])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def rotate_func(img, degree, fill=(0, 0, 0)):
    """Rotate by `degree` (degrees, like PIL, not radians) about the center."""
    h, w = img.shape[0], img.shape[1]
    rotation = cv2.getRotationMatrix2D((w / 2, h / 2), degree, 1)
    return cv2.warpAffine(img, rotation, (w, h), borderValue=fill)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def solarize_func(img, thresh=128):
    """Invert pixel values at or above `thresh` (PIL.ImageOps.solarize)."""
    lut = np.array([v if v < thresh else 255 - v for v in range(256)])
    lut = lut.clip(0, 255).astype(np.uint8)
    return lut[img]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def color_func(img, factor):
    """Color (saturation) enhancement, same output as PIL.ImageEnhance.Color.

    Implemented as one matrix blend between the image and its luminance
    (BGR weights), instead of PIL's two-pass degenerate/blend approach.
    """
    blend = np.array(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]],
        dtype=np.float32,
    ) * factor + np.array([[0.114], [0.587], [0.299]], dtype=np.float32)
    return np.matmul(img, blend).clip(0, 255).astype(np.uint8)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def contrast_func(img, factor):
    """Contrast enhancement, same output as PIL.ImageEnhance.Contrast."""
    # Luminance-weighted global mean (BGR channel order).
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    lut = (
        np.array([(v - mean) * factor + mean for v in range(256)])
        .clip(0, 255)
        .astype(np.uint8)
    )
    return lut[img]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def brightness_func(img, factor):
    """Scale brightness by `factor` via a LUT (PIL.ImageEnhance.Brightness)."""
    lut = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    return lut[img]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def sharpness_func(img, factor):
    """Sharpness enhancement; matches PIL except on the 1-px border."""
    smooth = np.ones((3, 3), dtype=np.float32)
    smooth[1][1] = 5
    smooth /= 13
    degenerate = cv2.filter2D(img, -1, smooth)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        # Blend the interior between the smoothed and original image.
        out = img.astype(np.float32)
        inner = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = inner + factor * (out[1:-1, 1:-1, :] - inner)
        out = out.astype(np.uint8)
    return out
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Horizontal shear by `factor` with constant-color border fill."""
    h, w = img.shape[0], img.shape[1]
    affine = np.array([[1, factor, 0], [0, 1, 0]], dtype=np.float32)
    return cv2.warpAffine(
        img, affine, (w, h), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """Horizontal translation by `offset` px, same output as PIL.Image.transform."""
    h, w = img.shape[0], img.shape[1]
    affine = np.array([[1, 0, -offset], [0, 1, 0]], dtype=np.float32)
    return cv2.warpAffine(
        img, affine, (w, h), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """Vertical translation by `offset` px, same output as PIL.Image.transform."""
    h, w = img.shape[0], img.shape[1]
    affine = np.array([[1, 0, 0], [0, 1, -offset]], dtype=np.float32)
    return cv2.warpAffine(
        img, affine, (w, h), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def posterize_func(img, bits):
    """Keep only the top `bits` bits per channel (PIL.ImageOps.posterize).

    Fix: the mask is reduced mod 256 explicitly — NumPy >= 2.0 raises
    OverflowError on np.uint8(255 << (8 - bits)) for bits < 8 instead of
    wrapping as older NumPy did.
    """
    mask = np.uint8((255 << (8 - bits)) & 0xFF)
    return np.bitwise_and(img, mask)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Vertical shear by `factor` with constant-color border fill."""
    h, w = img.shape[0], img.shape[1]
    affine = np.array([[1, 0, 0], [factor, 1, 0]], dtype=np.float32)
    return cv2.warpAffine(
        img, affine, (w, h), borderValue=fill, flags=cv2.INTER_LINEAR
    ).astype(np.uint8)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """Fill a random square of side ~`pad_size` with the `replace` color."""
    fill = np.array(replace, dtype=np.uint8)
    h, w = img.shape[0], img.shape[1]
    rel_y, rel_x = np.random.random(2)
    half = pad_size // 2
    cy, cx = int(rel_y * h), int(rel_x * w)
    # Clamp the square to the image bounds.
    y1, y2 = max(cy - half, 0), min(cy + half, h)
    x1, x2 = max(cx - half, 0), min(cx + half, w)
    out = img.copy()
    out[y1:y2, x1:x2, :] = fill
    return out
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
### level to args
|
| 217 |
+
def enhance_level_to_args(MAX_LEVEL):
    """Map a level in [0, MAX_LEVEL] to an enhancement factor in [0.1, 1.9]."""
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
    return level_to_args
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Map a level to a shear factor in [-0.3, 0.3] with a random sign."""
    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            magnitude = -magnitude
        return (magnitude, replace_value)
    return level_to_args
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Map a level to a translation offset in [-translate_const, translate_const]
    with a random sign."""
    def level_to_args(level):
        offset = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            offset = -offset
        return (offset, replace_value)
    return level_to_args
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Map a level to an integer cutout pad size in [0, cutout_const]."""
    def level_to_args(level):
        return (int((level / MAX_LEVEL) * cutout_const), replace_value)
    return level_to_args
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def solarize_level_to_args(MAX_LEVEL):
    """Map a level to a solarize threshold in [0, 256]."""
    def level_to_args(level):
        return (int((level / MAX_LEVEL) * 256),)
    return level_to_args
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def none_level_to_args(level):
    """Parameter-free augmentations take an empty argument tuple."""
    return ()
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def posterize_level_to_args(MAX_LEVEL):
    """Map a level to a posterize bit count in [0, 4]."""
    def level_to_args(level):
        return (int((level / MAX_LEVEL) * 4),)
    return level_to_args
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Map a level to a rotation angle in [-30, 30] degrees with a random sign."""
    def level_to_args(level):
        degrees = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            degrees = -degrees
        return (degrees, replace_value)
    return level_to_args
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Name -> implementation for each augmentation op.
func_dict = dict(
    Identity=identity_func,
    AutoContrast=autocontrast_func,
    Equalize=equalize_func,
    Rotate=rotate_func,
    Solarize=solarize_func,
    Color=color_func,
    Contrast=contrast_func,
    Brightness=brightness_func,
    Sharpness=sharpness_func,
    ShearX=shear_x_func,
    TranslateX=translate_x_func,
    TranslateY=translate_y_func,
    Posterize=posterize_func,
    ShearY=shear_y_func,
)

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)

# Name -> callable mapping an integer level to that op's argument tuple.
arg_dict = dict(
    Identity=none_level_to_args,
    AutoContrast=none_level_to_args,
    Equalize=none_level_to_args,
    Rotate=rotate_level_to_args(MAX_LEVEL, replace_value),
    Solarize=solarize_level_to_args(MAX_LEVEL),
    Color=enhance_level_to_args(MAX_LEVEL),
    Contrast=enhance_level_to_args(MAX_LEVEL),
    Brightness=enhance_level_to_args(MAX_LEVEL),
    Sharpness=enhance_level_to_args(MAX_LEVEL),
    ShearX=shear_level_to_args(MAX_LEVEL, replace_value),
    TranslateX=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    TranslateY=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    Posterize=posterize_level_to_args(MAX_LEVEL),
    ShearY=shear_level_to_args(MAX_LEVEL, replace_value),
)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class RandomAugment(object):
    """Apply N randomly-chosen augmentations, each with probability 0.5,
    at magnitude level M."""

    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # Fall back to the full op list when no subset is supplied.
        self.augs = augs if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Sample N (name, probability, level) op descriptors."""
        chosen = np.random.choice(self.augs, self.N)
        return [(name, 0.5, self.M) for name in chosen]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        for name, prob, level in self.get_random_ops():
            # Each sampled op is applied with probability `prob`.
            if np.random.random() > prob:
                continue
            img = func_dict[name](img, *arg_dict[name](level))
        return img
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if __name__ == "__main__":
    # Smoke test: run the default augmentation pipeline on a random image.
    augmenter = RandomAugment()
    sample = np.random.randn(32, 32, 3)
    augmenter(sample)
|
GOAL_github/utils/transforms.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 3 |
+
from torchvision.transforms import Compose, CenterCrop, ToTensor, Normalize, Resize
|
| 4 |
+
from utils.randaugment import RandomAugment
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms.functional as FT
|
| 7 |
+
import math
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
normalize = transforms.Normalize(
|
| 11 |
+
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
class transform_train:
    """Training preprocessing: random resized crop, horizontal flip, and
    CLIP-style normalization."""

    def __init__(self, image_size=384, min_scale=0.5):
        ops = [
            transforms.RandomResizedCrop(
                image_size,
                scale=(min_scale, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
        self.transform = transforms.Compose(ops)

    def __call__(self, img):
        return self.transform(img)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# class transform_train:
|
| 34 |
+
# def __init__(self, image_size=384, min_scale=0.5):
|
| 35 |
+
# self.transform = transforms.Compose(
|
| 36 |
+
# [
|
| 37 |
+
# transforms.RandomResizedCrop(
|
| 38 |
+
# image_size,
|
| 39 |
+
# scale=(min_scale, 1.0),
|
| 40 |
+
# interpolation=InterpolationMode.BICUBIC,
|
| 41 |
+
# ),
|
| 42 |
+
# transforms.RandomHorizontalFlip(),
|
| 43 |
+
# RandomAugment(
|
| 44 |
+
# 2,
|
| 45 |
+
# 5,
|
| 46 |
+
# isPIL=True,
|
| 47 |
+
# augs=[
|
| 48 |
+
# "Identity",
|
| 49 |
+
# "AutoContrast",
|
| 50 |
+
# "Brightness",
|
| 51 |
+
# "Sharpness",
|
| 52 |
+
# "Equalize",
|
| 53 |
+
# "ShearX",
|
| 54 |
+
# "ShearY",
|
| 55 |
+
# "TranslateX",
|
| 56 |
+
# "TranslateY",
|
| 57 |
+
# "Rotate",
|
| 58 |
+
# ],
|
| 59 |
+
# ),
|
| 60 |
+
# transforms.ToTensor(),
|
| 61 |
+
# normalize,
|
| 62 |
+
# ]
|
| 63 |
+
# )
|
| 64 |
+
|
| 65 |
+
# def __call__(self, img):
|
| 66 |
+
# return self.transform(img)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class transform_test(transforms.Compose):
    """Deterministic eval preprocessing: square resize + CLIP normalization.

    Fix: the original subclassed transforms.Compose but never called
    super().__init__, leaving the base uninitialized (no .transforms, broken
    repr). The base is now initialized with the same pipeline; the .transform
    attribute and __call__ behavior are kept for backward compatibility.
    """

    def __init__(self, image_size=384):
        ops = [
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            normalize,
        ]
        super().__init__(ops)
        # Kept for callers that access .transform directly.
        self.transform = transforms.Compose(ops)

    def __call__(self, img):
        return self.transform(img)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TargetPad:
    """Pad images whose aspect ratio exceeds `target_ratio` toward that ratio.

    For more details see Baldrati et al. 'Effective conditioned and composed
    image retrieval combining clip-based features.' CVPR 2022.
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: maximum allowed aspect ratio before padding
        :param size: preprocessing output dimension
        """
        self.size = size
        self.target_ratio = target_ratio

    def __call__(self, image):
        w, h = image.size
        aspect = max(w, h) / min(w, h)
        # Within the allowed ratio: nothing to do.
        if aspect < self.target_ratio:
            return image
        # Pad the short side so the resulting ratio equals target_ratio.
        scaled_max_wh = max(w, h) / self.target_ratio
        pad_w = max(int((scaled_max_wh - w) / 2), 0)
        pad_h = max(int((scaled_max_wh - h) / 2), 0)
        return FT.pad(image, [pad_w, pad_h, pad_w, pad_h], 0, 'constant')
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _convert_image_to_rgb(image):
|
| 114 |
+
return image.convert("RGB")
|
| 115 |
+
|
| 116 |
+
def targetpad_transform(target_ratio: float, dim: int) -> torch.Tensor:
    """
    CLIP-like preprocessing transform applied after TargetPad padding.
    :param target_ratio: target ratio for TargetPad
    :param dim: image output dimension
    :return: CLIP-like torchvision Compose transform
    """
    steps = [
        TargetPad(target_ratio, dim),
        Resize(dim, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(dim),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ]
    return Compose(steps)
|
GOAL_github/visualization/visualization_attentionmap_longtestset.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import lightning as L
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import transformers
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from utils.func import *
|
| 16 |
+
from utils.transforms import *
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import cv2
|
| 19 |
+
|
| 20 |
+
# Hyperparameters
micro_batch_size = 1  # one image per batch — attention is visualized for a single image
devices = 1  # number of accelerator devices handed to Fabric
num_workers = 1  # DataLoader worker processes (dataset holds a single item)
|
| 24 |
+
|
| 25 |
+
class ImageAttentionLoader(Dataset):
    """Single-image dataset yielding (pixel_values, original image array).

    Wraps one image path so the attention-map script can reuse the standard
    DataLoader machinery.
    """

    def __init__(self, image_path, processor):
        # processor: HuggingFace image processor (e.g. CLIPProcessor) —
        # assumed to accept images= and return_tensors= kwargs.
        self.image_path = image_path
        self.processor = processor

    def __len__(self):
        # Exactly one image is wrapped.
        return 1

    def _load_image(self):
        """Open the image from disk as RGB."""
        return Image.open(self.image_path).convert("RGB")

    def __getitem__(self, idx):
        """Return (pixel_values tensor [C, H, W], original image as HxWxC array)."""
        image = self._load_image()
        # Remember the PIL (width, height) so callers can map patches back.
        self.original_size = image.size
        # Ask the processor for tensors directly instead of round-tripping
        # through torch.tensor() on a list of numpy arrays (slow, warns).
        data = self.processor(images=image, return_tensors="pt")
        pixel_values = data['pixel_values']
        if not torch.is_tensor(pixel_values):
            # Processor ignored return_tensors and gave numpy/list output.
            pixel_values = torch.as_tensor(np.array(pixel_values))
        # Drop only the leading batch dim; the original bare .squeeze()
        # could also remove a genuine size-1 channel/spatial dim.
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.squeeze(0)
        return pixel_values, np.array(image)
|
| 47 |
+
|
| 48 |
+
def get_attention_maps(model, image_input):
    """Run the vision encoder and return its last-layer self-attention.

    Heads are averaged and the CLS row/column are stripped; the result is a
    float32 CPU tensor of per-patch attention (singleton dims squeezed).
    """
    print("Input shape before processing:", image_input.shape)

    # Promote a single [C, H, W] image to a batch of one.
    batched = image_input.unsqueeze(0) if len(image_input.shape) == 3 else image_input
    batched = batched.to(model.device)
    print("Input shape after processing:", batched.shape)

    outputs = model.vision_model(pixel_values=batched, output_attentions=True)

    # Last transformer layer's attention: [B, heads, tokens, tokens].
    attn = outputs.attentions[-1]
    print("Attention maps shape:", attn.shape)

    # Average across heads, then drop the CLS token row and column so only
    # patch-to-patch attention remains.
    attn = attn.mean(dim=1)[:, 1:, 1:]

    # Move to CPU as float32 and drop singleton dims.
    attn = attn.squeeze().cpu().float()
    print("Final attention maps shape:", attn.shape)

    return attn
|
| 75 |
+
|
| 76 |
+
def visualize_attention(image, attention_maps, output_path, model_name):
    """Save a 3-panel figure: original image, attention heatmap, overlay.

    :param image: original image as an HxWxC numpy array
    :param attention_maps: 1-D tensor with one attention score per patch,
        assumed to form a square patch grid
    :param output_path: file path for the saved figure
    :param model_name: unused; kept for call-site compatibility (the
        original derived a patch_size from it that was never read)
    """
    # Fold the flat per-patch scores back into the square patch grid.
    # reshape() raises if the patch count is not a perfect square.
    h = int(np.sqrt(attention_maps.shape[0]))
    attention_maps = attention_maps.reshape(h, h)

    # Upsample the patch grid to the original image resolution.
    attention_maps = F.interpolate(
        attention_maps.unsqueeze(0).unsqueeze(0),
        size=image.shape[:2],
        mode='bicubic'
    ).squeeze().numpy()

    # Min-max normalize to [0, 1]; guard the constant-map case, where the
    # original divided by zero and produced NaNs.
    span = attention_maps.max() - attention_maps.min()
    if span > 0:
        attention_maps = (attention_maps - attention_maps.min()) / span
    else:
        attention_maps = np.zeros_like(attention_maps)

    # Visualization
    plt.figure(figsize=(15, 5))

    # Original image
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    # Attention map
    plt.subplot(1, 3, 2)
    plt.imshow(attention_maps, cmap='jet')
    plt.title('Self-Attention Map')
    plt.axis('off')

    # Overlay of heatmap on the original image
    plt.subplot(1, 3, 3)
    plt.imshow(image)
    plt.imshow(attention_maps, cmap='jet', alpha=0.5)
    plt.title('Overlay')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
|
| 124 |
+
|
| 125 |
+
def main(args):
    """Load a (optionally fine-tuned) CLIP model and save a self-attention
    visualization for a single input image.

    :param args: parsed CLI namespace — image_path, output_path, model
        (shorthand), ckpt, new_max_token
    """
    # NOTE(review): likely redundant — CUDA_VISIBLE_DEVICES is already set
    # at import time, and re-setting it here has no effect once CUDA has
    # been initialized.
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    fabric = L.Fabric(
        accelerator="cuda",
        devices=devices,
        precision="bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Expand the model shorthand into a HuggingFace model id.
    if args.model=='L-336':
        args.model = 'openai/clip-vit-large-patch14-336'
    elif args.model=='L':
        args.model = 'openai/clip-vit-large-patch14'
    elif args.model=='B':
        args.model = 'openai/clip-vit-base-patch32'
    elif args.model=='G':
        args.model = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    with fabric.device:
        processor = transformers.AutoProcessor.from_pretrained(args.model)
        model = transformers.AutoModel.from_pretrained(args.model, output_attentions=True).bfloat16()
        # Expand the text position embeddings to new_max_token
        # (LongCLIP-style interpolation; helper from utils.func).
        longclip_pos_embeddings(model, args.new_max_token)
        if args.ckpt:
            # Load fine-tuned weights on top; strict=False tolerates the
            # resized embedding table.
            model.load_state_dict(torch.load(args.ckpt), strict=False)

    # Build the single-image data loader.
    dataset = ImageAttentionLoader(args.image_path, processor)
    dataloader = DataLoader(dataset, batch_size=micro_batch_size, shuffle=False)

    # Switch the model to evaluation mode on the Fabric device.
    model.eval()
    model = model.to(fabric.device)

    # Generate and visualize the attention map.
    with torch.no_grad():
        for batch in dataloader:
            image_input, original_image = batch
            print("Original input shape:", image_input.shape)
            attention_maps = get_attention_maps(model, image_input)

            # Average attention received by each patch over all query patches.
            avg_attention = attention_maps.mean(dim=0)

            visualize_attention(
                original_image[0],
                avg_attention,
                args.output_path,
                args.model
            )
            break
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
    # Allow TF32 matmuls on Ampere+ GPUs for speed.
    torch.set_float32_matmul_precision("high")

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True, help="Path to input image")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save visualization")
    # Extended text token limit consumed by longclip_pos_embeddings.
    parser.add_argument('--new_max_token', default=248, type=int)
    # Model shorthand: L-336 / L / B / G (expanded to a HF model id in main()).
    parser.add_argument("--model", type=str, default='L')
    # Optional fine-tuned checkpoint (state_dict) to load on top.
    parser.add_argument("--ckpt", type=str, default='')
    args = parser.parse_args()

    main(args)
|
GOAL_github/visualization/visualization_retreival.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import lightning as L
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import transformers
|
| 11 |
+
from torch.utils.data import DataLoader, Dataset
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from utils.func import *
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import cv2
|
| 17 |
+
import textwrap
|
| 18 |
+
|
| 19 |
+
def collate_fn(batch):
    """Collate gallery items, right-padding token sequences with zeros.

    Every item contributes 'pixel_values', 'input_ids', 'attention_mask'
    tensors plus 'image_path'/'caption' strings; sequences are padded to the
    longest 'input_ids' in the batch.
    """
    # Longest token sequence in this batch.
    target_len = max(len(item['input_ids']) for item in batch)

    def _pad(seq):
        # Append zero ids / zero mask entries up to target_len.
        tail = torch.zeros(target_len - len(seq), dtype=torch.long)
        return torch.cat([seq, tail])

    return {
        'pixel_values': torch.stack([item['pixel_values'] for item in batch]),
        'input_ids': torch.stack([_pad(item['input_ids']) for item in batch]),
        'attention_mask': torch.stack([_pad(item['attention_mask']) for item in batch]),
        'image_path': [item['image_path'] for item in batch],
        'caption': [item['caption'] for item in batch],
    }
|
| 60 |
+
|
| 61 |
+
class JsonGalleryDataset(Dataset):
    """Image/caption gallery backed by a JSON list of
    {"filename": ..., "caption": ...} records.

    Each item yields CLIP-processed pixel values plus the tokenized caption
    padded/truncated to `max_length`.
    """

    def __init__(self, json_path, processor, max_length=248):
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = item['filename']
        caption = item['caption']

        try:
            image = Image.open(image_path).convert("RGB")
            # Process both modalities with the shared CLIP processor.
            processed_image = self.processor(images=image, return_tensors="pt")
            processed_text = self.processor(
                text=caption,
                return_tensors="pt",
                padding='max_length',
                max_length=self.max_length,
                truncation=True
            )

            return {
                'pixel_values': processed_image['pixel_values'].squeeze(0),
                'input_ids': processed_text['input_ids'].squeeze(0),
                'attention_mask': processed_text['attention_mask'].squeeze(0),
                'image_path': image_path,
                'caption': caption
            }
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Fall back to the first item, but re-raise when item 0 itself
            # fails: the original unconditional `return self.__getitem__(0)`
            # recursed forever in that case.
            if idx == 0:
                raise
            return self.__getitem__(0)
|
| 98 |
+
|
| 99 |
+
def compute_similarities(model, query_features, gallery_features):
    """Cosine-similarity matrix between query and gallery feature rows.

    `model` is not used; the parameter is kept for call-site compatibility.
    """
    q = F.normalize(query_features, dim=-1)
    g = F.normalize(gallery_features, dim=-1)
    return q @ g.t()
|
| 104 |
+
|
| 105 |
+
def process_query(model, processor, query, device, is_image=False, max_length=248):
    """Embed a query with the matching CLIP encoder.

    `query` is a path to an image when is_image is True, otherwise raw text.
    Returns the (un-normalized) feature tensor from the model.
    """
    with torch.no_grad():
        if not is_image:
            encoded = processor(
                text=query,
                return_tensors="pt",
                padding='max_length',
                max_length=max_length,
                truncation=True
            )
            return model.get_text_features(
                encoded['input_ids'].to(device),
                encoded['attention_mask'].to(device)
            )
        image = Image.open(query).convert("RGB")
        encoded = processor(images=image, return_tensors="pt")
        return model.get_image_features(encoded['pixel_values'].to(device))
|
| 124 |
+
|
| 125 |
+
def visualize_results(query, results, output_path, is_image_query=False):
    """Save a two-panel retrieval figure: query on the left, top-1 match on
    the right (image query -> caption result; text query -> image result).

    :param query: image path or text string, depending on is_image_query
    :param results: list of (image_path, similarity, caption) tuples;
        only results[0] is rendered
    :param output_path: file path for the saved figure
    :param is_image_query: True when `query` is an image path
    """
    # Only the top-1 result is shown.
    n_results = 1
    plt.rcParams['figure.facecolor'] = 'white'
    plt.rcParams['axes.facecolor'] = 'white'
    fig, axes = plt.subplots(1, 1 + n_results, figsize=(10, 4)) # adjust figure size

    if is_image_query:
        # Image query: show the query image on the left.
        query_img = Image.open(query)
        axes[0].imshow(query_img)
        axes[0].set_title("Image query", fontsize=12, pad=10)

        # Show the retrieved result (caption only).
        img_path, similarity, caption = results[0] # use only the first result
        axes[1].set_facecolor('white')
        wrapped_text = textwrap.fill(caption, width=40)

        # Render the caption as plain centered text (no text box).
        axes[1].text(0.5, 0.5, wrapped_text,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=axes[1].transAxes,
                    wrap=True,
                    fontsize=10)
        # axes[1].set_title("Recall@1", fontsize=12, pad=10)

    else:
        # Text query: render the query string on the left.
        axes[0].set_facecolor('white')
        wrapped_text = textwrap.fill(query, width=40)

        # Render the query as plain centered text (no text box).
        axes[0].text(0.5, 0.5, wrapped_text,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=axes[0].transAxes,
                    wrap=True,
                    fontsize=10)
        axes[0].set_title("Text query", fontsize=12, pad=10)

        # Show the retrieved result (image only).
        img_path, similarity, caption = results[0] # use only the first result
        img = Image.open(img_path)
        axes[1].imshow(img)
        # axes[1].set_title("Recall@1", fontsize=12, pad=10)

    # Common styling for every subplot.
    for ax in axes:
        ax.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['top'].set_visible(True)
        ax.spines['right'].set_visible(True)
        ax.spines['bottom'].set_visible(True)
        ax.spines['left'].set_visible(True)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0)

    plt.savefig(output_path,
                dpi=300,
                bbox_inches='tight',
                facecolor='white',
                edgecolor='none',
                pad_inches=0.2)
    plt.close()
|
| 192 |
+
|
| 193 |
+
def main(args):
    """Run top-1 cross-modal retrieval over a JSON gallery and save a
    visualization of the query and its best match.

    :param args: parsed CLI namespace — query, gallery_json, output_path,
        model (shorthand), ckpt, new_max_token
    """
    fabric = L.Fabric(
        accelerator="cuda",
        devices=1,
        precision="bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    # Expand the model shorthand into a HuggingFace model id.
    if args.model == 'L-336':
        model_name = 'openai/clip-vit-large-patch14-336'
    elif args.model == 'L':
        model_name = 'openai/clip-vit-large-patch14'
    elif args.model == 'B':
        model_name = 'openai/clip-vit-base-patch32'
    elif args.model == 'G':
        model_name = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

    # Load model and processor.
    processor = transformers.CLIPProcessor.from_pretrained(model_name)
    model = transformers.CLIPModel.from_pretrained(model_name).bfloat16()

    # First expand the text position embeddings (LongCLIP-style).
    longclip_pos_embeddings(model, args.new_max_token)

    # Then load the fine-tuned weights on top of the resized model.
    if args.ckpt:
        state_dict = torch.load(args.ckpt, weights_only=True)
        model.load_state_dict(state_dict, strict=False)

    model = model.to(fabric.device)
    model.eval()

    # Prepare the gallery dataset / loader.
    dataset = JsonGalleryDataset(args.gallery_json, processor, args.new_max_token)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn
    )

    # Decide whether the query is an image path or a text string.
    is_image_query = args.query.endswith(('.jpg', '.jpeg', '.png'))

    # Extract gallery features (opposite modality to the query).
    gallery_features = []
    gallery_items = []

    with torch.no_grad():
        for batch in dataloader:
            if is_image_query:
                # Image query: the gallery is searched by caption features.
                input_ids = batch['input_ids'].to(fabric.device)
                attention_mask = batch['attention_mask'].to(fabric.device)
                features = model.get_text_features(input_ids, attention_mask)
            else:
                # Text query: the gallery is searched by image features.
                pixel_values = batch['pixel_values'].to(fabric.device)
                features = model.get_image_features(pixel_values)

            gallery_features.append(features)

            for path, caption in zip(batch['image_path'], batch['caption']):
                gallery_items.append({
                    'image_path': path,
                    'caption': caption
                })

    gallery_features = torch.cat(gallery_features, dim=0)

    # Embed the query itself.
    query_features = process_query(model, processor, args.query, fabric.device, is_image_query, args.new_max_token)

    # Compute similarities and keep only the single best match.
    similarities = compute_similarities(model, query_features, gallery_features)
    top_k = 1 # changed to fetch only one result
    top_k_similarities, top_k_indices = torch.topk(similarities[0].float(), k=top_k)

    # Collect (path, similarity, caption) tuples for the winners.
    results = []
    top_k_similarities_np = top_k_similarities.cpu().numpy()
    top_k_indices_np = top_k_indices.cpu().numpy()

    for sim, idx in zip(top_k_similarities_np, top_k_indices_np):
        item = gallery_items[idx]
        results.append((
            item['image_path'],
            sim,
            item['caption']
        ))

    # Render and save the retrieval figure.
    visualize_results(
        args.query,
        results,
        args.output_path,
        is_image_query
    )
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
    # Allow TF32 matmuls on Ampere+ GPUs for speed.
    torch.set_float32_matmul_precision("high")

    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str, required=True,
                       help="Path to query image or text query")
    parser.add_argument("--gallery_json", type=str, required=True,
                       help="Path to JSON file containing gallery information")
    parser.add_argument("--output_path", type=str, required=True,
                       help="Path to save visualization")
    # Model shorthand expanded to a HF model id in main().
    parser.add_argument("--model", type=str, default='L',
                       choices=['L-336', 'L', 'B', 'G'])
    parser.add_argument("--ckpt", type=str, default='',
                       help="Path to custom checkpoint")
    parser.add_argument("--new_max_token", type=int, default=248,
                       help="Maximum number of text tokens")

    args = parser.parse_args()
    main(args)
|