# Source: GOAL / GOAL_github / visualization / visualization_retreival.py
# Provenance: uploaded by qkenr0804 ("Upload 29 files", commit e90e75c, verified)
import json
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import sys
from pathlib import Path
import lightning as L
import numpy as np
import torch
import transformers
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torch.nn.functional as F
from utils.func import *
import matplotlib.pyplot as plt
import cv2
import textwrap
def collate_fn(batch):
    """Stack gallery samples into a batch, right-padding text to the longest
    sequence in the batch.

    Each sample is the dict produced by ``JsonGalleryDataset.__getitem__``;
    the 'pixel_values' tensors are assumed to share one shape, while
    'input_ids' / 'attention_mask' may differ in length.  Both are padded
    with zeros (mask zeros mark the padded positions as ignored).
    """
    longest = max(len(sample['input_ids']) for sample in batch)

    def pad(seq):
        # Right-pad a 1-D tensor with zeros; F.pad preserves the dtype.
        return F.pad(seq, (0, longest - len(seq)))

    return {
        'pixel_values': torch.stack([sample['pixel_values'] for sample in batch]),
        'input_ids': torch.stack([pad(sample['input_ids']) for sample in batch]),
        'attention_mask': torch.stack([pad(sample['attention_mask']) for sample in batch]),
        'image_path': [sample['image_path'] for sample in batch],
        'caption': [sample['caption'] for sample in batch],
    }
class JsonGalleryDataset(Dataset):
    """Gallery of (image, caption) pairs described by a JSON list.

    The JSON file must contain a list of records with at least
    ``{"filename": <image path>, "caption": <text>}``.  Each item is run
    through the given CLIP processor for both the image and the caption.
    """

    def __init__(self, json_path, processor, max_length=248):
        """Load the gallery description.

        Args:
            json_path: Path to the JSON file listing gallery entries.
            processor: CLIP processor used for both image and text encoding.
            max_length: Max token length for caption padding/truncation.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with processed image/text tensors plus raw metadata.

        On a loading/processing failure the first sample is returned as a
        fallback; if the first sample itself fails, the error is re-raised
        (the original code recursed into index 0 unconditionally, which
        loops forever when sample 0 is the broken one).
        """
        item = self.data[idx]
        image_path = item['filename']
        caption = item['caption']
        try:
            image = Image.open(image_path).convert("RGB")
            # Process both the image and the caption with the same processor.
            processed_image = self.processor(images=image, return_tensors="pt")
            processed_text = self.processor(
                text=caption,
                return_tensors="pt",
                padding='max_length',
                max_length=self.max_length,
                truncation=True
            )
            return {
                'pixel_values': processed_image['pixel_values'].squeeze(0),
                'input_ids': processed_text['input_ids'].squeeze(0),
                'attention_mask': processed_text['attention_mask'].squeeze(0),
                'image_path': image_path,
                'caption': caption
            }
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            if idx == 0:
                # The fallback sample is itself unloadable; surface the error
                # instead of recursing forever.
                raise
            return self.__getitem__(0)
def compute_similarities(model, query_features, gallery_features):
    """Return the cosine-similarity matrix between queries and gallery items.

    Args:
        model: Unused; kept so existing call sites keep working.
        query_features: (num_queries, dim) embedding tensor.
        gallery_features: (num_gallery, dim) embedding tensor.

    Returns:
        (num_queries, num_gallery) tensor of cosine similarities.
    """
    q = F.normalize(query_features, dim=-1)
    g = F.normalize(gallery_features, dim=-1)
    return q @ g.t()
def process_query(model, processor, query, device, is_image=False, max_length=248):
    """Embed a single query with the CLIP model, without tracking gradients.

    Args:
        model: CLIP model exposing ``get_image_features`` / ``get_text_features``.
        processor: CLIP processor for both modalities.
        query: Image file path when ``is_image`` is True, otherwise a text string.
        device: Device the input tensors are moved to.
        is_image: Selects the image branch instead of the text branch.
        max_length: Padding/truncation length for the text branch.

    Returns:
        The feature tensor produced by the corresponding CLIP encoder.
    """
    with torch.no_grad():
        if not is_image:
            encoded = processor(
                text=query,
                return_tensors="pt",
                padding='max_length',
                max_length=max_length,
                truncation=True
            )
            return model.get_text_features(
                encoded['input_ids'].to(device),
                encoded['attention_mask'].to(device)
            )
        query_image = Image.open(query).convert("RGB")
        encoded = processor(images=query_image, return_tensors="pt")
        return model.get_image_features(encoded['pixel_values'].to(device))
def visualize_results(query, results, output_path, is_image_query=False):
    """Render the query next to its top-1 retrieved match and save the figure.

    For an image query, the left panel shows the query image and the right
    panel the retrieved caption; for a text query, the left panel shows the
    query text and the right panel the retrieved image.

    Args:
        query: Image path or text string, depending on ``is_image_query``.
        results: Non-empty list of (image_path, similarity, caption) tuples;
            only the first entry is used.
        output_path: File path the figure is written to.
        is_image_query: Whether ``query`` is an image path.
    """
    plt.rcParams['figure.facecolor'] = 'white'
    plt.rcParams['axes.facecolor'] = 'white'
    # One panel for the query plus one for the single (top-1) result.
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    top_path, top_similarity, top_caption = results[0]

    def show_text(ax, text):
        # Centered, line-wrapped text on a plain white panel (no text box).
        ax.set_facecolor('white')
        ax.text(0.5, 0.5, textwrap.fill(text, width=40),
                horizontalalignment='center',
                verticalalignment='center',
                transform=ax.transAxes,
                wrap=True,
                fontsize=10)

    if is_image_query:
        axes[0].imshow(Image.open(query))
        axes[0].set_title("Image query", fontsize=12, pad=10)
        show_text(axes[1], top_caption)
    else:
        show_text(axes[0], query)
        axes[0].set_title("Text query", fontsize=12, pad=10)
        axes[1].imshow(Image.open(top_path))

    # Shared styling: hide axes ticks, keep the panel borders visible.
    for ax in axes:
        ax.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(True)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0)
    plt.savefig(output_path,
                dpi=300,
                bbox_inches='tight',
                facecolor='white',
                edgecolor='none',
                pad_inches=0.2)
    plt.close()
def main(args):
    """Run top-1 cross-modal retrieval for one query and save a visualization.

    The query (args.query) is treated as an image if its extension looks like
    one, otherwise as text.  The gallery listed in args.gallery_json is then
    embedded in the *opposite* modality, the single best match is selected by
    cosine similarity, and the pair is rendered to args.output_path.
    """
    fabric = L.Fabric(
        accelerator="cuda",
        devices=1,
        precision="bf16-mixed"
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)
    # Model selection
    if args.model == 'L-336':
        model_name = 'openai/clip-vit-large-patch14-336'
    elif args.model == 'L':
        model_name = 'openai/clip-vit-large-patch14'
    elif args.model == 'B':
        model_name = 'openai/clip-vit-base-patch32'
    elif args.model == 'G':
        model_name = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
    # Load the model and processor
    processor = transformers.CLIPProcessor.from_pretrained(model_name)
    model = transformers.CLIPModel.from_pretrained(model_name).bfloat16()
    # Expand the text position embeddings FIRST (LongCLIP-style helper from
    # utils.func), so the checkpoint below fits the resized text encoder.
    longclip_pos_embeddings(model, args.new_max_token)
    # Then load the fine-tuned weights, if a checkpoint was given.
    # strict=False: presumably the checkpoint covers only a subset of keys —
    # TODO confirm against how the checkpoints were saved.
    if args.ckpt:
        state_dict = torch.load(args.ckpt, weights_only=True)
        model.load_state_dict(state_dict, strict=False)
    model = model.to(fabric.device)
    model.eval()
    # Build the gallery dataset and loader
    dataset = JsonGalleryDataset(args.gallery_json, processor, args.new_max_token)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn
    )
    # Decide whether the query is an image (by extension) or text
    is_image_query = args.query.endswith(('.jpg', '.jpeg', '.png'))
    # Extract gallery features in the modality opposite to the query
    gallery_features = []
    gallery_items = []
    with torch.no_grad():
        for batch in dataloader:
            if is_image_query:
                # Image query -> embed the gallery captions (text features)
                input_ids = batch['input_ids'].to(fabric.device)
                attention_mask = batch['attention_mask'].to(fabric.device)
                features = model.get_text_features(input_ids, attention_mask)
            else:
                # Text query -> embed the gallery images (image features)
                pixel_values = batch['pixel_values'].to(fabric.device)
                features = model.get_image_features(pixel_values)
            gallery_features.append(features)
            # Keep (path, caption) metadata aligned with the feature rows
            for path, caption in zip(batch['image_path'], batch['caption']):
                gallery_items.append({
                    'image_path': path,
                    'caption': caption
                })
    gallery_features = torch.cat(gallery_features, dim=0)
    # Extract the query feature
    query_features = process_query(model, processor, args.query, fabric.device, is_image_query, args.new_max_token)
    # Compute similarities and keep only the top-1 result
    similarities = compute_similarities(model, query_features, gallery_features)
    top_k = 1  # only a single (Recall@1) result is visualized
    top_k_similarities, top_k_indices = torch.topk(similarities[0].float(), k=top_k)
    # Collect (image_path, similarity, caption) tuples for the winners
    results = []
    top_k_similarities_np = top_k_similarities.cpu().numpy()
    top_k_indices_np = top_k_indices.cpu().numpy()
    for sim, idx in zip(top_k_similarities_np, top_k_indices_np):
        item = gallery_items[idx]
        results.append((
            item['image_path'],
            sim,
            item['caption']
        ))
    # Render the query/result pair to disk
    visualize_results(
        args.query,
        results,
        args.output_path,
        is_image_query
    )
if __name__ == "__main__":
    # Allow TF32 matmuls for faster float32 GEMMs on supporting GPUs.
    torch.set_float32_matmul_precision("high")
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str, required=True,
                        help="Path to query image or text query")
    parser.add_argument("--gallery_json", type=str, required=True,
                        help="Path to JSON file containing gallery information")
    parser.add_argument("--output_path", type=str, required=True,
                        help="Path to save visualization")
    # CLIP backbone size: Large-336px, Large, Base, or bigG (laion).
    parser.add_argument("--model", type=str, default='L',
                        choices=['L-336', 'L', 'B', 'G'])
    parser.add_argument("--ckpt", type=str, default='',
                        help="Path to custom checkpoint")
    # Must match the position-embedding length the checkpoint was trained with.
    parser.add_argument("--new_max_token", type=int, default=248,
                        help="Maximum number of text tokens")
    args = parser.parse_args()
    main(args)