# QCQC / src/generate_relevance_scores.py
# Uploaded by Johnny050407 using huggingface_hub (commit 9ed01de, verified).
import os
import torch
import pickle
import sys
import argparse
from utils import load_index, get_text_feature, get_faiss_sim
# Absolute directory that contains this script; anchor for all relative paths.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]


def resolve_path(*parts):
    """Return the absolute path obtained by joining ``parts`` onto ``BASE_DIR``.

    Args:
        *parts: Path components appended, in order, to the script directory.
            Components such as ``'../'`` are collapsed by the normalization
            that ``os.path.abspath`` performs.

    Returns:
        The normalized absolute path as a string.
    """
    joined = os.path.join(BASE_DIR, *parts)
    return os.path.abspath(joined)
from search_preparation import index_configs
def main():
    """Score the COCO captions against a FAISS index and save the merged data.

    Steps, in order:
      1. Parse the command line (``--texts``, ``--search_k``).
      2. Load the index stored under ``../coco_faiss_indexes/<model_id>``
         (resolved relative to this script's directory).
      3. Load ``data_aes.pt`` and ``data_IQA.pt`` from ``../processed_data/coco``
         and check that both carry the same captions and image ids.
      4. Pass the captions through ``get_text_feature`` and the index's
         feature transform, then query the index via ``get_faiss_sim`` with
         ``--search_k``.
      5. Write the combined dictionary to ``data.pt`` in the same directory.

    Raises:
        FileNotFoundError: If the index directory does not exist.
        SystemExit: With status 1 when the two input files are not aligned.
    """
    parser = argparse.ArgumentParser(description="Dataset-condition query completion")
    # NOTE(review): --texts is parsed but never read below; the dataset
    # directory is fixed to 'coco'. Confirm whether it should select the input.
    parser.add_argument('--texts', nargs='+', default=['coco'], help='Choices of captions')
    parser.add_argument('--search_k', type=int, default=1)
    args = parser.parse_args()

    model_id = 'CLIP-Huge-Flickr-Flat'
    out_dir = resolve_path('../', 'coco_faiss_indexes')
    index_dir = os.path.join(out_dir, model_id)
    if not os.path.exists(index_dir):
        raise FileNotFoundError(f"Index directory not found: {index_dir}")
    index, feat_transform, img_ids = load_index(index_dir)

    model_cfg = index_configs[model_id]
    ai_config = model_cfg['a1_config']
    weight_path = model_cfg['weight_path']

    text_dir = resolve_path('../', 'processed_data', 'coco')
    save_text = True  # developer toggle: set to False to skip writing data.pt

    # torch.load unpickles its input; only point this at trusted, locally
    # produced preprocessing outputs.
    loaded_data = torch.load(os.path.join(text_dir, 'data_aes.pt'))
    loaded_data2 = torch.load(os.path.join(text_dir, 'data_IQA.pt'))

    # Consistency check: both files must describe the same samples in the
    # same order, otherwise the per-sample scores merged below would be
    # attached to the wrong captions.
    # NOTE(review): `!=` here assumes plain Python sequences; tensor/array
    # values would make the truth test ambiguous -- confirm the stored types.
    if loaded_data['captions'] != loaded_data2['captions'] or \
            loaded_data['image_ids'] != loaded_data2['image_ids']:
        # sys.exit(str) writes the message to stderr and exits with status 1,
        # so the failure does not get lost in redirected stdout.
        sys.exit("Two datasets are not aligned. Please check preprocessing.")

    texts = loaded_data['captions']
    image_ids = loaded_data['image_ids']

    # Extract text features and map them into the index's feature space.
    q_feats = get_text_feature(texts, ai_config, weight_path)
    q_feats = feat_transform(q_feats)

    # Use the GPU when CUDA is present (the previous behavior) instead of
    # unconditionally requesting it, so the search can run on CPU-only hosts.
    # NOTE(review): assumes get_faiss_sim accepts use_gpu=False -- confirm.
    faiss_sim, img_hash_list = get_faiss_sim(
        args.search_k, index, q_feats, img_ids,
        use_gpu=torch.cuda.is_available(),
    )

    if save_text:
        data = {
            "texts": texts,
            "faiss_sim": faiss_sim,
            "faiss_img": img_hash_list,
            "aesthetics": loaded_data['aesthetics'],
            "image_ids": image_ids,
            "IQAs": loaded_data2['IQAs'],
        }
        torch.save(data, os.path.join(text_dir, 'data.pt'))
        print(f"data saved in {text_dir}!")


if __name__ == '__main__':
    main()