| import os |
| import sys |
| import argparse |
| import numpy as np |
| import torch |
| from utils import * |
| from datasets import Dataset |
| from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, GPT2Tokenizer, GPT2LMHeadModel |
| from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList |
|
|
|
|
# Make this script's directory and its parent importable regardless of the
# current working directory the script is launched from (needed for the
# `search_preparation` import below and the `utils` star-import above).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)
# NOTE(review): the parent dir is inserted unconditionally, so repeated runs
# in the same interpreter prepend duplicates — harmless but untidy.
sys.path.insert(0, os.path.join(BASE_DIR, ".."))


from search_preparation import index_configs
|
|
|
|
def resolve_path(*parts):
    """Join *parts* under the script directory and return the absolute path."""
    joined = os.path.join(BASE_DIR, *parts)
    return os.path.abspath(joined)
|
|
|
|
def load_index_f(index_dir):
    """Load a serialized FAISS index, its feature-transform chain, and image ids.

    Parameters
    ----------
    index_dir : str
        Directory containing 'faiss_IVPQ_PCA.index', the serialized
        VectorTransforms ('norm1.bin', and optionally 'pca.bin' + 'norm2.bin'),
        and 'img_ids.pt'.

    Returns
    -------
    tuple
        (index, feat_transform, img_ids) where `feat_transform(x)` applies
        norm1 (then PCA + norm2 when a PCA file exists) to a feature matrix.

    NOTE(review): `faiss` is never imported in this file; presumably it is
    pulled in via `from utils import *` — confirm.
    """
    print(os.getcwd())  # debug aid: shows where relative paths resolve from
    index_path = os.path.join(index_dir, 'faiss_IVPQ_PCA.index')
    index = faiss.read_index(index_path)

    # First normalization applied to raw query features.
    norm1 = faiss.read_VectorTransform(os.path.join(index_dir, "norm1.bin"))
    # PCA (plus a second normalization) is optional: applied only when the
    # serialized PCA transform was produced at index-build time.
    do_pca = os.path.exists(os.path.join(index_dir, "pca.bin"))
    if do_pca:
        pca = faiss.read_VectorTransform(os.path.join(index_dir, "pca.bin"))
        norm2 = faiss.read_VectorTransform(os.path.join(index_dir, "norm2.bin"))

    def feat_transform(x):
        # Chain the transforms in the same order used when the index was built:
        # norm1 -> (pca -> norm2 when present).
        x = norm1.apply_py(x)
        if do_pca:
            x = pca.apply_py(x)
            x = norm2.apply_py(x)
        return x

    # Maps FAISS row position -> image identifier; apparently filename-like
    # strings such as "12345.jpg" (see get_scores_coco, which splits on ".").
    # weights_only=False because the file stores pickled Python objects.
    img_ids = torch.load(os.path.join(index_dir, 'img_ids.pt'), weights_only=False)

    return index, feat_transform, img_ids
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any stop sequence appears in the decoded output.

    Each element of `stops` is a token-id tensor; at every generation step the
    full sequence so far is decoded and checked for the decoded stop string.

    Fixes over the original:
    - `stops` no longer uses a mutable default argument (shared list pitfall).
    - Stop tensors are moved to CUDA only when CUDA is available, so the
      criteria can be built on CPU-only machines instead of crashing.
    """

    def __init__(self, tokenizer, stops=None, encounters=1):
        super().__init__()
        self.tokenizer = tokenizer
        # Guarded device move: the original unconditionally called .to("cuda").
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.stops = [stop.to(device) for stop in (stops if stops is not None else [])]
        # NOTE(review): `encounters` is accepted but never used — kept only for
        # backward compatibility with existing callers.

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Only the first sequence in the batch is inspected (callers generate
        # with num_return_sequences=1).
        decoded = self.tokenizer.decode(input_ids[0])
        return any(self.tokenizer.decode(stop) in decoded for stop in self.stops)
|
|
def query_completion(model_name, ori_query, args, text_dir):
    """Complete each query in `ori_query` `args.cmpl_k` times with a language model.

    Parameters
    ----------
    model_name : str
        'no_completion' (pass-through), 'gpt2' / 'Qwen' (pretrained hub models),
        or a fine-tuned checkpoint name like 'model_ADS' resolved under
        `text_dir`.
    ori_query : sequence of str
        Prompts to complete (one per class).
    args : argparse.Namespace
        Decoding hyper-parameters (cmpl_k, min_len, max_len, tmpr, top_k,
        top_p, rept_pnal, no_rept_ngram, no_condition).
    text_dir : str
        Root directory holding the fine-tuned checkpoints.

    Returns
    -------
    list[list[str]]
        One inner list of `args.cmpl_k` completed queries per input query
        (shape preserved even in the 'no_completion' pass-through case).
    """
    # --- Resolve where the model weights live -------------------------------
    if model_name == "gpt2" or model_name == "Qwen":
        # Pretrained hub identifier, no local checkpoint path needed.
        model_save_path = model_name
    elif model_name == 'no_completion':
        # Pass-through: duplicate each original query cmpl_k times so the
        # output shape matches the completed case.
        print(f"No query completion, copy original query {args.cmpl_k} times to keep the shape ! ")
        return [[item] * args.cmpl_k for item in ori_query]
    else:
        # Fine-tuned checkpoint, e.g. 'model_ADS' -> <text_dir>/gpt2cocoADS/model_ADS/checkpoint
        suffix = model_name.replace("model_", "", 1) if model_name.startswith("model_") else model_name
        if "Qwen" in model_name:
            tmp_path = os.path.join(text_dir, f"qwencoco{suffix}", model_name)
        else:
            tmp_path = os.path.join(text_dir, f"gpt2coco{suffix}", model_name)
        model_save_path = os.path.join(tmp_path, "checkpoint")
    print(f"loading model from {model_save_path}...")

    # --- Load model + tokenizer --------------------------------------------
    if model_name == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained(model_save_path)
        model = GPT2LMHeadModel.from_pretrained(model_save_path)
        # GPT-2 has no pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == "Qwen":
        model = AutoModelForCausalLM.from_pretrained(model_save_path+"/Qwen2.5-0.5B")
        tokenizer = AutoTokenizer.from_pretrained(model_save_path+"/Qwen2.5-0.5B")
        tokenizer.pad_token = tokenizer.eos_token
    else:
        # Fine-tuned checkpoints: pick the loader family from the name.
        if "Qwen" in model_name:
            model = AutoModelForCausalLM.from_pretrained(model_save_path)
            tokenizer = AutoTokenizer.from_pretrained(model_save_path)
        else:
            tokenizer = GPT2Tokenizer.from_pretrained(model_save_path)
            model = GPT2LMHeadModel.from_pretrained(model_save_path)

    # NOTE(review): the model is never moved to GPU here, so generation runs
    # on CPU while StoppingCriteriaSub moves stop ids to "cuda" — confirm the
    # intended device.
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {total_params}")

    # Stop generating as soon as a sentence terminator is produced.
    stop_words = [".", "!", "?"]
    stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(tokenizer, stops=stop_words_ids)])

    if args.no_condition:
        # Unconditional sampling: ignore ori_query entirely.
        # NOTE(review): the number of generations (80) is hard-coded rather
        # than derived from len(ori_query) — confirm intended.
        with torch.no_grad():
            autocompleted_queries = []
            for _ in range(80):
                generated_ids = model.generate(
                    do_sample=True,
                    min_length=args.min_len,
                    max_new_tokens=args.max_len,
                    temperature=args.tmpr,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    repetition_penalty=args.rept_pnal,
                    no_repeat_ngram_size=args.no_rept_ngram,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    stopping_criteria=stopping_criteria,
                    num_return_sequences=1,
                )
                autocompleted_queries.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
            # Duplicate each sample cmpl_k times to keep the output shape.
            autocompleted_queries = [[item] * args.cmpl_k for item in autocompleted_queries]
    else:
        # Conditional completion: continue each prompt cmpl_k times.
        with torch.no_grad():
            tokenized_inputs = [tokenizer(text, padding=True, truncation=True, return_tensors="pt") for text in ori_query]
            autocompleted_queries = []
            for input_data in tokenized_inputs:
                queries_for_each_class = []
                for i in range(args.cmpl_k):
                    generated_ids = model.generate(
                        input_data['input_ids'],
                        attention_mask=input_data["attention_mask"],
                        do_sample=True,
                        min_length=args.min_len,
                        max_new_tokens=args.max_len,
                        temperature=args.tmpr,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=args.rept_pnal,
                        no_repeat_ngram_size=args.no_rept_ngram,
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.pad_token_id,
                        stopping_criteria=stopping_criteria,
                        return_dict_in_generate=False,
                        num_return_sequences=1,
                    )
                    # Decoded text includes the prompt; the caller strips it
                    # by splitting on 'Query: '.
                    queries_for_each_class.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
                autocompleted_queries.append(queries_for_each_class)
    return autocompleted_queries
|
|
|
|
def image_retrieve_coco(sear_k, index, q_feats, loaded_data, img_ids):
    """Retrieve the top `sear_k` images per query and score them.

    Parameters
    ----------
    sear_k : int
        Number of nearest neighbours to fetch per query feature.
    index : faiss index
        The loaded FAISS index to search.
    q_feats : ndarray
        Transformed query features, one row per query.
    loaded_data : dict
        Metadata with 'aesthetics', 'IQAs' and 'image_ids' (see get_scores_coco).
    img_ids : sequence
        FAISS row position -> image identifier mapping.

    Returns
    -------
    tuple
        (aesthetics, faiss_smi, iqas, img_hash_list) as produced by
        get_scores_coco.

    Fix over the original: removed the dead locals `img_list` / `dis_list`
    that were created but never used.
    """
    # D: similarity scores, I: index positions of the top-k hits per query.
    D, I = index.search(q_feats, sear_k)

    aesthetics, faiss_smi, iqas, img_hash_list = get_scores_coco(I, D, loaded_data, img_ids)

    print_scores_iqa(aesthetics, faiss_smi, iqas)
    return aesthetics, faiss_smi, iqas, img_hash_list
|
|
|
|
def get_scores_coco(I, D, loaded_data, img_ids):
    """Translate FAISS hits into per-image aesthetic / IQA score tensors.

    Parameters
    ----------
    I : 2-D array of int
        Index positions of the retrieved images, one row per query.
    D : 2-D array of float
        FAISS similarity scores aligned with `I`.
    loaded_data : dict
        Metadata with keys 'aesthetics', 'IQAs' (score lists) and 'image_ids'
        (numeric ids aligned with those lists).
    img_ids : sequence
        FAISS row position -> image identifier (filename-like strings such as
        "12345.jpg"; the numeric stem keys into `loaded_data`).

    Returns
    -------
    tuple
        (aesthetics, faiss_smi, iqas, img_hash) — score tensors of shape
        (num_queries, k), the similarity tensor, and the retrieved ids.

    Fix over the original: replaced the per-element `in` + `list.index()`
    scans (each O(n), so O(n*k) per query row) with a single id -> position
    dict built once, and hoisted the repeated `int(s.split(".")[0])` parse
    and the fallback means out of the inner loop. First-occurrence semantics
    of `list.index` are preserved.
    """
    aesthetics_score = torch.tensor(loaded_data["aesthetics"])
    IQAs_score = torch.tensor(loaded_data["IQAs"])

    # id -> first position, matching what list.index() would have returned.
    id_to_pos = {}
    for pos, image_id in enumerate(loaded_data["image_ids"]):
        if image_id not in id_to_pos:
            id_to_pos[image_id] = pos

    # Fallback for images missing from the metadata: the dataset mean.
    aes_mean = aesthetics_score.mean()
    iqa_mean = IQAs_score.mean()

    # Translate index positions into image identifiers, one row per query.
    img_hash = [img_ids[idx] for idx in I]

    aesthetics = []
    iqas = []
    for each_class in img_hash:
        aes_row = []
        iqa_row = []
        for s in each_class:
            # "12345.jpg" -> 12345, parsed once per element.
            pos = id_to_pos.get(int(s.split(".")[0]))
            aes_row.append(aesthetics_score[pos] if pos is not None else aes_mean)
            iqa_row.append(IQAs_score[pos] if pos is not None else iqa_mean)
        aesthetics.append(torch.stack(aes_row))
        iqas.append(torch.stack(iqa_row))

    aesthetics = torch.stack(aesthetics)
    iqas = torch.stack(iqas)
    faiss_smi = torch.tensor(D)

    return aesthetics, faiss_smi, iqas, img_hash
|
|
|
|
|
|
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Dataset-condition query completion")

    # How many completions per class and how many images per completed query.
    parser.add_argument('--cmpl_k', type=int, default=10, help='perform k times query completion')
    parser.add_argument('--sear_k', type=int, default=50, help='search k images for each completed query')

    # Decoding hyper-parameters forwarded verbatim to model.generate().
    parser.add_argument('--min_len', type=int, default=10, help='minimal length in query completion')
    parser.add_argument('--max_len', type=int, default=20, help='maximal length in query completion')
    parser.add_argument('--tmpr', type=float, default=0.7, help='temperature in query completion')
    parser.add_argument('--top_k', type=int, default=50, help='select tok_k tokens in query completion')
    parser.add_argument('--top_p', type=float, default=0.9, help='top_p in query completion')
    parser.add_argument('--rept_pnal', type=float, default=1.2, help='repetition_penalty in query completion')
    parser.add_argument('--no_rept_ngram', type=int, default=2, help='no_repeat_ngram_size in query completion')

    parser.add_argument('--seed', type=int, default=42, help='seed')

    parser.add_argument('--search_imgs', action='store_true', help='load pre-downloaded data for image retrieval')
    parser.add_argument('--an_img_showing', action='store_true', help='query: an image showing x')
    parser.add_argument('--prt_cmpl_qry', action='store_true', help='print some completed queries for illustration')

    parser.add_argument('--no_condition', action='store_true', help='generate texts without conditions')
    parser.add_argument('--model_names', nargs='+', default=['no_completion', 'gpt2'], help='Choices of completion models')

    # Condition levels substituted into the prompt templates below.
    parser.add_argument('--aes_level', type=str, default='high', help='low or high or median')
    parser.add_argument('--sim_level', type=str, default='high', help='low or high or median')
    parser.add_argument('--iqa_level', type=str, default='high', help='low or high or median')

    args = parser.parse_args()
    set_seed(args.seed)

    # Resolve project paths relative to this script (not the CWD).
    work_space = resolve_path("..")
    text_dir = resolve_path("..", "outputs")
    coco_class_path = os.path.join(work_space, "MC-COCO-Class.txt")
    if not os.path.exists(coco_class_path):
        raise FileNotFoundError(f"Class list not found: {coco_class_path}")
    a_cls_list, an_image_showing_list = load_ori_query(coco_class_path)

    index_dir = resolve_path("..", "coco_faiss_indexes", "CLIP-Huge-Flickr-Flat")
    if not os.path.isdir(index_dir):
        raise FileNotFoundError(f"Index dir not found: {index_dir}. Run search_preparation.py first.")
    # NOTE(review): this file defines `load_index_f`, not `load_index`;
    # presumably `load_index` arrives via `from utils import *` — confirm,
    # otherwise this is a NameError and should call load_index_f.
    index, feat_transform, img_ids = load_index(index_dir)
    print("loaded index, transform, img_ids.")

    # Per-index config (feature extractor settings + weights) from search_preparation.
    model_id = 'CLIP-Huge-Flickr-Flat'
    ai_config = index_configs[model_id]['a1_config']
    weight_path = index_configs[model_id]['weight_path']

    for model_name in args.model_names:
        print(args)

        # One row per COCO class name.
        data_dict = {'text': np.array(a_cls_list)}
        dataset = Dataset.from_dict(data_dict)

        # Prompt layouts differ only in the ordering of the three conditions;
        # the key suffix (ADS/DSA/SAD/SA) encodes that order.
        PROMPT_TEMPLATES = {
            "model_ADS": "<|startoftext|>Aesthetic: {aes}, DeQA-Score: {iqa}, Similarity: {sim}, Query: {cond}",
            "model_DSA": "<|startoftext|>DeQA-Score: {iqa}, Similarity: {sim}, Aesthetic: {aes}, Query: {cond}",
            "model_SAD": "<|startoftext|>Similarity: {sim}, Aesthetic: {aes}, DeQA-Score: {iqa}, Query: {cond}",
            "model_SA": "<|startoftext|>Similarity: {sim}, Aesthetic: {aes}, Query: {cond}",
        }
        # NOTE(review): the default --model_names values ('no_completion',
        # 'gpt2') are NOT keys of PROMPT_TEMPLATES, so a default run raises
        # here even though query_completion handles both names — confirm
        # whether those models need a plain (condition-free) template.
        if model_name not in PROMPT_TEMPLATES:
            raise ValueError(f"Invalid model name: {model_name}. Valid: {list(PROMPT_TEMPLATES)}")
        prompt = PROMPT_TEMPLATES[model_name]

        def apply_prompt_template(sample):
            # Fill the chosen template with the CLI condition levels and the
            # class name; produces the 'prompt' column consumed below.
            sim = args.sim_level
            aes = args.aes_level
            iqa = args.iqa_level
            con_str = sample["text"]
            return {"prompt": prompt.format(aes=aes, iqa=iqa, sim=sim, cond=con_str)}

        dataset = dataset.map(apply_prompt_template, remove_columns=["text"])
        queries = query_completion(model_name, dataset['prompt'], args, text_dir)

        # Keep only the text after 'Query: ' from the first completion of each
        # class (the prompt prefix is echoed back by the generator).
        autocompleted_queries = []
        for item in queries:
            query_text = item[0].split('Query: ')[1]
            autocompleted_queries.append([query_text])

        if args.prt_cmpl_qry:
            print(f"------------------------- {model_name} --------------------------------")
            for ii in range(len(autocompleted_queries)):
                for jj, query in enumerate(autocompleted_queries[ii]):
                    print(query)

        if args.search_imgs:
            data_path = resolve_path("..", "processed_data", "coco", "data.pt")
            if not os.path.exists(data_path):
                raise FileNotFoundError(f"Processed data not found: {data_path}")
            # weights_only=False: the file stores pickled metadata dicts.
            loaded_data = torch.load(data_path, weights_only=False)

            # Embed the completed queries, L2-normalize, then apply the same
            # transform chain (norm/PCA) the index was built with.
            q_feats = get_text_list_feature(autocompleted_queries, ai_config, weight_path)
            q_feats = torch.tensor(np.array(q_feats)).squeeze()
            q_feats /= q_feats.norm(dim=-1, keepdim=True)
            q_feats = feat_transform(q_feats.numpy())

            aesthetics, faiss_smi, iqas, img_hash_list = image_retrieve_coco(args.sear_k, index, q_feats, loaded_data, img_ids)
|
|
| |
| |