| import torch |
| import os |
| import numpy as np |
| import faiss |
| import open_clip |
| import functools |
| import re |
| from tqdm import tqdm |
| import ipdb |
|
|
| from torch.utils.data import DataLoader |
|
|
|
|
def contains_special_characters(text):
    '''
    Tell whether ``text`` holds at least one non-ASCII character.

    A character counts as "special" when its code point lies outside the
    7-bit ASCII range (anything above U+007F).

    Returns:
        True if such a character is found, False otherwise.
    '''
    return re.search(r'[^\x00-\x7F]', text) is not None
|
|
def check_texts_for_special_characters(texts):
    '''
    Build one report line per text that contains non-ASCII characters.

    Args:
        texts: iterable of strings to inspect.

    Returns:
        A list of messages of the form "Text <i>: Contains special characters",
        where <i> is the position of the offending text in ``texts``. Texts
        without special characters produce no entry.
    '''
    return [
        f"Text {position}: Contains special characters"
        for position, candidate in enumerate(texts)
        if contains_special_characters(candidate)
    ]
|
|
def clean_text(text):
    '''
    Normalise a single string for downstream tokenisation.

    Steps, in order:
      1. drop every run of non-ASCII characters,
      2. collapse every run of whitespace into one space,
      3. trim leading and trailing whitespace.

    Returns:
        The cleaned string.
    '''
    ascii_only = re.sub(r'[^\x00-\x7F]+', '', text)
    single_spaced = re.sub(r'\s+', ' ', ascii_only)
    return single_spaced.strip()
|
|
def clean_texts(texts):
    '''
    Apply ``clean_text`` to every entry of ``texts``.

    Returns:
        A new list with the cleaned strings, in the same order as the input.
    '''
    return list(map(clean_text, texts))
|
|
|
|
|
|
|
|
def load_ori_query(coco_class_path):
    '''
    Read class names from a text file and build two prompt lists from them.

    The file is expected to hold one class name per line. Each line is
    stripped of surrounding whitespace; lines that are empty after stripping
    are skipped (previously an empty line raised IndexError while picking
    the article).

    Args:
        coco_class_path: path of the text file with the class names.

    Returns:
        A tuple ``(a_cls_list, an_image_showing_list)``:
          * ``a_cls_list`` - each class prefixed with "a " or "an ",
            "an" being used when the first letter is a vowel (a, e, i, o, u;
            case-insensitive);
          * ``an_image_showing_list`` - "an image showing <class>" for each
            class (the bare class name, without the article).
    '''
    with open(coco_class_path, 'r') as file:
        stripped_lines = (line.strip() for line in file)
        # Skip blank lines so that item[0] below is always defined.
        coco_classes = [name for name in stripped_lines if name]

    def add_article_to_classes(class_list):
        # Choose the indefinite article from the first letter only.
        return [
            f"an {item}" if item[0].lower() in 'aeiou' else f"a {item}"
            for item in class_list
        ]

    a_cls_list = add_article_to_classes(coco_classes)
    an_image_showing_list = [f"an image showing {cls}" for cls in coco_classes]

    return a_cls_list, an_image_showing_list
|
|
|
|
|
|
def load_index(index_dir):
    '''
    Load a FAISS index together with its feature transforms and image ids.

    Files read from ``index_dir``:
      * ``faiss_IVPQ_PCA.index`` - the index itself (``faiss.read_index``);
      * ``norm1.bin``            - vector transform, always loaded;
      * ``pca.bin``              - vector transform, loaded only if it exists;
      * ``norm2.bin``            - vector transform, loaded together with
                                   ``pca.bin``;
      * ``img_ids.npy``          - array loaded with ``np.load``.

    The current working directory is printed before anything is loaded.

    NOTE(review): the source this was recovered from had lost its
    indentation; this version assumes ``norm2`` is loaded and applied only
    when ``pca.bin`` exists - confirm against the index-building script.

    Returns:
        ``(index, feat_transform, img_ids)`` where ``feat_transform(x)``
        applies norm1 and then, when PCA is available, pca followed by norm2.
    '''
    print(os.getcwd())

    def in_index_dir(file_name):
        return os.path.join(index_dir, file_name)

    index = faiss.read_index(in_index_dir('faiss_IVPQ_PCA.index'))

    norm1 = faiss.read_VectorTransform(in_index_dir('norm1.bin'))
    do_pca = os.path.exists(in_index_dir('pca.bin'))
    if do_pca:
        pca = faiss.read_VectorTransform(in_index_dir('pca.bin'))
        norm2 = faiss.read_VectorTransform(in_index_dir('norm2.bin'))

    def feat_transform(x):
        # Same transform chain that is loaded above, applied in load order.
        transformed = norm1.apply_py(x)
        if not do_pca:
            return transformed
        transformed = pca.apply_py(transformed)
        return norm2.apply_py(transformed)

    img_ids = np.load(in_index_dir('img_ids.npy'))

    return index, feat_transform, img_ids
|
|
|
|
def load_model(config_name, weight_path):
    '''
    Create an open_clip model in eval mode plus its tokenizer.

    Args:
        config_name: model configuration name passed to open_clip.
        weight_path: value forwarded as ``pretrained`` to
            ``open_clip.create_model_and_transforms``.

    Returns:
        ``(model, tokenizer)``. The model is moved to "cuda" when CUDA is
        available, otherwise it is cast with ``.float()`` and kept on "cpu".
        The transforms returned by open_clip are discarded.
    '''
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only the model is needed here; both returned transforms are ignored.
    model, _, _ = open_clip.create_model_and_transforms(config_name, pretrained=weight_path)
    tokenizer = open_clip.get_tokenizer(config_name)

    if device == 'cpu':
        model = model.float()
    model = model.to(device)

    model.eval()
    return model, tokenizer
|
|
|
|
|
|
|
|
def get_text_list_feature(query_list, ai_config, weight_path):
    '''
    Encode groups of text queries with the CLIP text encoder.

    query_list: n classes, each class has k queries. Every class is
    tokenised and encoded as its own batch.

    Returns:
        A list with one numpy array per class, each being the output of
        ``model.encode_text`` for that class's queries.
    '''
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_model(ai_config, weight_path)

    # Tokenise every class first, then encode, then move results to numpy.
    token_batches = []
    for class_queries in query_list:
        token_batches.append(tokenizer(class_queries).to(device))

    with torch.no_grad():
        encoded_batches = [model.encode_text(tokens) for tokens in token_batches]

    return [encoded.cpu().numpy() for encoded in encoded_batches]
|
|
|
|
|
|
|
|
def get_text_feature(query_list, ai_config, weight_path):
    '''
    Encode a flat list of text queries with the CLIP text encoder.

    query_list: n queries. All queries are tokenised at once and then
    encoded in chunks of 1000 rows, with a tqdm progress bar over the chunks.

    After encoding, the model reference is deleted and
    ``torch.cuda.empty_cache()`` is called.

    Returns:
        A numpy array with the concatenated ``model.encode_text`` outputs,
        in query order.
    '''
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_model(ai_config, weight_path)

    tokens = tokenizer(query_list).to(device)
    total = tokens.shape[0]
    batch_size = 1000

    encoded_chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, total, batch_size)):
            chunk = tokens[start:start + batch_size]
            encoded_chunks.append(model.encode_text(chunk))
        text_feats = torch.cat(encoded_chunks, dim=0)

    # Drop the model before returning so its memory can be reclaimed.
    del model
    torch.cuda.empty_cache()

    return text_feats.cpu().numpy()
|
|
|
|
|
|
|
|
def print_scores(aesthetics, faiss_smi):
    '''
    Print summary statistics for aesthetics and similarity scores.

    Both arguments are converted with ``np.array``. For each of them the
    function prints:
      * the mean along axis 0, rounded to 3 decimals ("for each completion"),
      * the mean over all entries,
      * the standard deviation over all entries,
    followed by a separator line of dashes.
    '''
    aes_arr = np.array(aesthetics)
    smi_arr = np.array(faiss_smi)

    # Axis-0 means: one value per column of the score table.
    aes_per_completion = np.around(np.mean(aes_arr, axis=0), decimals=3)
    smi_per_completion = np.around(np.mean(smi_arr, axis=0), decimals=3)

    aes_mean = np.mean(aes_arr)
    aes_std = np.std(aes_arr)
    smi_mean = np.mean(smi_arr)
    smi_std = np.std(smi_arr)

    print("avg aesthetics for each completion:", ' '.join(str(v) for v in aes_per_completion))
    print(f"avg aesthetics over all images: {aes_mean:.3f}")
    print(f"std aesthetics over all images: {aes_std:.3f}")
    print("avg similarities for each completion:", ' '.join(str(v) for v in smi_per_completion))
    print(f"avg similarities over all images: {smi_mean:.3f}")
    print(f"std similarities over all images: {smi_std:.3f}")
    print("---------------------------------------------------------------------------")
|
|
|
|
|
|
def print_scores_iqa(aesthetics, faiss_smi, iqas):
    '''
    Print summary statistics for aesthetics, similarity and IQA scores.

    Each argument is converted with ``np.array``. For each of the three
    score tables the function prints:
      * the mean along axis 0, rounded to 3 decimals ("for each completion"),
      * the mean over all entries,
      * the standard deviation over all entries.
    A separator line of dashes is printed at the end.
    '''
    aes_arr = np.array(aesthetics)
    smi_arr = np.array(faiss_smi)
    iqa_arr = np.array(iqas)

    # Axis-0 means: one value per column of each score table.
    aes_per_completion = np.around(np.mean(aes_arr, axis=0), decimals=3)
    smi_per_completion = np.around(np.mean(smi_arr, axis=0), decimals=3)
    iqa_per_completion = np.around(np.mean(iqa_arr, axis=0), decimals=3)

    aes_mean = np.mean(aes_arr)
    aes_std = np.std(aes_arr)
    smi_mean = np.mean(smi_arr)
    smi_std = np.std(smi_arr)
    iqa_mean = np.mean(iqa_arr)
    iqa_std = np.std(iqa_arr)

    print("avg aesthetics for each completion:", ' '.join(str(v) for v in aes_per_completion))
    print(f"avg aesthetics over all images: {aes_mean:.3f}")
    print(f"std aesthetics over all images: {aes_std:.3f}")
    print("avg similarities for each completion:", ' '.join(str(v) for v in smi_per_completion))
    print(f"avg similarities over all images: {smi_mean:.3f}")
    print(f"std similarities over all images: {smi_std:.3f}")

    print("avg IQA for each completion:", ' '.join(str(v) for v in iqa_per_completion))
    print(f"avg IQA over all images: {iqa_mean:.3f}")
    print(f"std IQA over all images: {iqa_std:.3f}")
    print("---------------------------------------------------------------------------")
|
|
|
|
|
|
def get_scores(img_list, dis_list, loaded_data, img_ids):
    '''
    Compute mean aesthetics and mean similarity for every completion.

    Args:
        img_list: one entry per class; each entry is a 2-level sequence of
            indices into ``img_ids`` (completions x retrieved images).
        dis_list: one entry per class; each entry is iterated per completion
            and ``.mean()`` is taken of every completion's values.
        loaded_data: mapping with the keys ``"aesthetics_score"`` (indexed by
            position, must support ``.mean()`` and produce values accepted
            by ``torch.stack``) and ``"strImagehash"`` (sequence of image
            hashes, aligned with ``aesthetics_score``).
        img_ids: lookup from index to image hash.

    Returns:
        ``(aesthetics, faiss_smi, img_hash_list)``:
          * ``aesthetics`` - stacked tensor, one row per class with the mean
            aesthetics score of each completion;
          * ``faiss_smi`` - tensor with the mean of every completion's
            values from ``dis_list``;
          * ``img_hash_list`` - ``img_list`` with every index replaced by
            ``img_ids[index]``.

    Hashes missing from ``strImagehash`` fall back to the mean of all
    aesthetics scores.
    '''
    aesthetics_score = loaded_data["aesthetics_score"]
    strImagehash = loaded_data["strImagehash"]

    # One pass over the hashes instead of a linear `in` + `.index` scan per
    # retrieved image. setdefault keeps the FIRST position of a duplicated
    # hash, which is what list.index returned.
    hash_to_pos = {}
    for pos, image_hash in enumerate(strImagehash):
        hash_to_pos.setdefault(image_hash, pos)

    fallback_cache = []

    def score_for(image_hash):
        pos = hash_to_pos.get(image_hash)
        if pos is not None:
            return aesthetics_score[pos]
        # Computed lazily so the mean is only evaluated when a hash is missing.
        if not fallback_cache:
            fallback_cache.append(aesthetics_score.mean())
        return fallback_cache[0]

    img_hash_list = [
        [[img_ids[idx] for idx in img] for img in imgs]
        for imgs in img_list
    ]

    aesthetics = []
    for each_class in img_hash_list:
        avg_aesthetic = [
            torch.stack([score_for(s) for s in each_completion]).mean()
            for each_completion in each_class
        ]
        aesthetics.append(torch.stack(avg_aesthetic))
    aesthetics = torch.stack(aesthetics)

    faiss_smi = [[each_completion.mean() for each_completion in each_class] for each_class in dis_list]
    faiss_smi = torch.tensor(faiss_smi)

    return aesthetics, faiss_smi, img_hash_list
|
|
|
|
|
|
|
|
def get_scores_prompt(img_list, dis_list, loaded_data, img_ids):
    '''
    Collect per-image aesthetics scores and raw similarity values.

    Unlike an averaged report, nothing is reduced here: every retrieved
    image keeps its own aesthetics score and every value of ``dis_list`` is
    kept as is.

    Args:
        img_list: one entry per class; each entry is a 2-level sequence of
            indices into ``img_ids`` (completions x retrieved images).
        dis_list: one entry per class; each entry is iterated per completion.
        loaded_data: mapping with the keys ``"aesthetics_score"`` (indexed by
            position, must support ``.mean()`` and produce values accepted
            by ``torch.stack``) and ``"strImagehash"`` (sequence of image
            hashes, aligned with ``aesthetics_score``).
        img_ids: lookup from index to image hash.

    Returns:
        ``(aesthetics_all, faiss_smi)``:
          * ``aesthetics_all`` - stacked tensor with one aesthetics score per
            retrieved image (classes x completions x images);
          * ``faiss_smi`` - tensor built from the values of ``dis_list``.

    Hashes missing from ``strImagehash`` fall back to the mean of all
    aesthetics scores.
    '''
    aesthetics_score = loaded_data["aesthetics_score"]
    strImagehash = loaded_data["strImagehash"]

    # One pass over the hashes instead of a linear `in` + `.index` scan per
    # retrieved image. setdefault keeps the FIRST position of a duplicated
    # hash, which is what list.index returned.
    hash_to_pos = {}
    for pos, image_hash in enumerate(strImagehash):
        hash_to_pos.setdefault(image_hash, pos)

    fallback_cache = []

    def score_for(image_hash):
        pos = hash_to_pos.get(image_hash)
        if pos is not None:
            return aesthetics_score[pos]
        # Computed lazily so the mean is only evaluated when a hash is missing.
        if not fallback_cache:
            fallback_cache.append(aesthetics_score.mean())
        return fallback_cache[0]

    img_hash_list = [
        [[img_ids[idx] for idx in img] for img in imgs]
        for imgs in img_list
    ]

    aesthetics_all = []
    for each_class in img_hash_list:
        aesthetic = [
            torch.stack([score_for(s) for s in each_completion])
            for each_completion in each_class
        ]
        aesthetics_all.append(torch.stack(aesthetic))
    aesthetics_all = torch.stack(aesthetics_all)

    # list(...) yields the per-completion rows unchanged.
    faiss_smi = torch.tensor([list(each_class) for each_class in dis_list])

    return aesthetics_all, faiss_smi
|
|
|
|
|
|
def image_retrive(sear_k, index, q_feats, loaded_data, img_ids):
    '''
    Retrieve the ``sear_k`` nearest images for every query batch and report scores.

    ``index.search`` is called once per element of ``q_feats``. The search
    results are scored with ``get_scores`` and the summary is printed with
    ``print_scores``.

    Returns:
        ``(img_hash_list, dis_list)`` - the image hashes produced by
        ``get_scores`` and the list of distance outputs returned by
        ``index.search``, one per element of ``q_feats``.
    '''
    # Each search yields a (distances, indices) pair.
    search_results = [index.search(q_feat, sear_k) for q_feat in q_feats]
    dis_list = [distances for distances, _ in search_results]
    img_list = [neighbours for _, neighbours in search_results]

    aesthetics, faiss_smi, img_hash_list = get_scores(img_list, dis_list, loaded_data, img_ids)

    print_scores(aesthetics, faiss_smi)
    return img_hash_list, dis_list
|
|
|
|
|
|
|
|
|
|
def image_retrive_prompt(sear_k, index, q_feats, loaded_data, img_ids):
    '''
    Retrieve the ``sear_k`` nearest images per query batch and return raw scores.

    ``index.search`` is called once per element of ``q_feats``; the results
    are passed to ``get_scores_prompt``.

    Returns:
        ``(aesthetics, faiss_smi)`` as returned by ``get_scores_prompt``,
        each with all size-1 dimensions removed.
    '''
    img_list = []
    dis_list = []
    for q_feat in q_feats:
        D, I = index.search(q_feat, sear_k)
        img_list.append(I)
        dis_list.append(D)
    # A leftover ipdb.set_trace() breakpoint used to stop execution here; it
    # was removed so the function can run unattended.

    aesthetics, faiss_smi = get_scores_prompt(img_list, dis_list, loaded_data, img_ids)
    # squeeze() without arguments already drops every size-1 dimension, so
    # one call is enough.
    return aesthetics.squeeze(), faiss_smi.squeeze()
|
|
|
|
|
|
def get_faiss_sim(sear_k, index, q_feats, img_ids, use_gpu, batch_size=100000):
    '''
    Search the index in batches and return distances and image hashes.

    Args:
        sear_k: number of neighbours requested from ``index.search``.
        index: FAISS index; moved to GPU 0 when ``use_gpu`` is true.
        q_feats: query features; rows are sliced in chunks of ``batch_size``.
        img_ids: array indexed with the neighbour indices returned by the
            search.
        use_gpu: when true, the index is copied to GPU 0 before searching.
        batch_size: number of query rows per ``index.search`` call
            (defaults to the previously hard-coded 100000).

    Returns:
        ``(faiss_smi, img_hash_list)``:
          * ``faiss_smi`` - tensor with the distances of all batches
            concatenated along dim 0, with size-1 dimensions removed;
          * ``img_hash_list`` - list with one ``img_ids[...]`` lookup per
            batch (indices squeezed per batch).
        For an empty ``q_feats`` an empty tensor and an empty list are
        returned.
    '''
    if use_gpu:
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)

    num = q_feats.shape[0]

    img_hash_list = []
    distance_batches = []

    for i in tqdm(range(0, num, batch_size)):
        D, I = index.search(q_feats[i:i + batch_size], sear_k)
        img_hash_list.append(img_ids[I.squeeze()])
        # Keep the batch 2-D here: squeezing per batch made torch.cat fail
        # whenever a batch held a single row.
        distance_batches.append(torch.from_numpy(D))

    if not distance_batches:
        # No queries: torch.cat would raise on an empty list.
        return torch.empty(0), img_hash_list

    # Squeezing once after concatenation gives the same shape as the old
    # per-batch squeeze in every case where that one succeeded.
    faiss_smi = torch.cat(distance_batches, dim=0).squeeze()

    return faiss_smi, img_hash_list
|
|