| import os |
| import os.path as osp |
| import random |
| import sys |
| import argparse |
| import pandas as pd |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from stark_qa.tools.api_lib.openai_emb import get_contriever, get_contriever_embeddings |
|
|
| sys.path.append('.') |
| from stark_qa import load_skb, load_qa |
| from stark_qa.tools.api import get_api_embeddings |
| from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings |
| from models.model import get_embeddings |
|
|
| import argparse |
|
|
def parse_args(argv=None):
    """Parse command-line options for the embedding-generation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Useful for tests and programmatic calls.

    Returns:
        A ``(args, encode_kwargs)`` tuple:

        - ``args`` is the full ``argparse.Namespace``.
        - ``encode_kwargs`` holds only the encoder-specific options
          (``n_max_nodes``, ``device``, ``peft_model_name``, ``instruction``)
          whose value is not ``None``. They can be forwarded as ``**kwargs``
          to the embedding function.
    """
    parser = argparse.ArgumentParser()

    # Dataset / model selection. 'contriever' is the default and has its own
    # code path in the main script, so it must also be an accepted choice.
    parser.add_argument('--dataset', default='prime', choices=['amazon', 'prime', 'mag'])
    parser.add_argument('--emb_model', default='contriever',
                        choices=[
                            'contriever',
                            'text-embedding-ada-002',
                            'text-embedding-3-small',
                            'text-embedding-3-large',
                            'voyage-large-2-instruct',
                            'GritLM/GritLM-7B',
                            'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
                            'all-mpnet-base-v2',
                        ])

    # Embed candidate documents or queries.
    parser.add_argument('--mode', default='query', choices=['doc', 'query'])

    # Input / output locations.
    parser.add_argument("--data_dir", default="data/", type=str)
    parser.add_argument("--emb_dir", default="emb/", type=str)

    # How document texts are built.
    parser.add_argument('--add_rel', action='store_true', default=False, help='add relation to the text')
    parser.add_argument('--compact', action='store_true', default=False, help='make the text compact when input to the model')

    parser.add_argument("--human_generated_eval", action="store_true", help="if mode is `query`, then generating query embeddings on human generated evaluation split")

    parser.add_argument("--batch_size", default=1024, type=int)

    # Encoder-specific options, forwarded to the embedding function only when
    # set. add_argument returns the Action it creates; keeping those objects
    # derives the forwarded set from these declarations instead of from
    # argparse's private lookup tables.
    encode_actions = [
        parser.add_argument("--n_max_nodes", default=None, type=int),
        parser.add_argument("--device", default=None, type=str),
        parser.add_argument("--peft_model_name", default=None, type=str, help="LLM2Vec PEFT model"),
        parser.add_argument("--instruction", type=str, help="GritLM/LLM2Vec instruction"),
    ]

    args = parser.parse_args(argv)

    encode_kwargs = {}
    for action in encode_actions:
        value = getattr(args, action.dest)
        if value is not None:
            encode_kwargs[action.dest] = value

    return args, encode_kwargs
| |
|
|
if __name__ == '__main__':
    # Entry point. Embeds either every candidate document of an SKB
    # (--mode doc) or every query of a QA split (--mode query). Embeddings
    # already saved at the output path are reused; only missing ids are encoded.
    args, encode_kwargs = parse_args()

    # Suffix describing how the texts were built. The embedding directory and
    # the CSV text cache share it, so different variants never overwrite each other.
    mode_suffix = '_human_generated_eval' if args.human_generated_eval and args.mode == 'query' else ''
    mode_suffix += '_no_rel' if not args.add_rel else ''
    mode_suffix += '_no_compact' if not args.compact else ''
    emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, f'{args.mode}{mode_suffix}')
    csv_cache = osp.join(args.data_dir, args.dataset, f'{args.mode}{mode_suffix}.csv')

    print(f'Embedding directory: {emb_dir}')
    os.makedirs(emb_dir, exist_ok=True)
    os.makedirs(osp.dirname(csv_cache), exist_ok=True)

    # Collect every id to embed. Copy into a fresh list before shuffling so the
    # SKB's own `candidate_ids` is not reordered in place.
    if args.mode == 'doc':
        skb = load_skb(args.dataset)
        all_indices = list(skb.candidate_ids)
        emb_path = osp.join(emb_dir, 'candidate_emb_dict.pt')
    else:  # 'query' -- argparse restricts --mode to 'doc' / 'query'
        qa_dataset = load_qa(args.dataset, human_generated_eval=args.human_generated_eval)
        # Element 1 of each QA item is used as the query id (it is passed to
        # `get_query_by_qid` below). Confirm this against stark_qa.
        all_indices = [qa_dataset[i][1] for i in range(len(qa_dataset))]
        emb_path = osp.join(emb_dir, 'query_emb_dict.pt')
    random.shuffle(all_indices)

    # Resume from a previous run: ids already present in the saved dict are skipped.
    if osp.exists(emb_path):
        emb_dict = torch.load(emb_path)
        print(f'Loaded existing embeddings from {emb_path}. Size: {len(emb_dict)}')
    else:
        emb_dict = {}

    if args.mode == 'doc' and osp.exists(csv_cache):
        # Reuse document texts cached by an earlier run. keep_default_na=False
        # and dtype=str keep every text a string. By default pandas reads empty
        # cells and texts such as "NA" or "null" as float NaN.
        df = pd.read_csv(csv_cache, dtype={'text': str}, keep_default_na=False)
        cache_dict = dict(zip(df['index'], df['text']))
        if set(cache_dict) != set(all_indices):
            raise ValueError('Indices in cache do not match the candidate indices.')
        indices = [idx for idx in all_indices if idx not in emb_dict]
        texts = [cache_dict[idx] for idx in tqdm(indices, desc="Filtering docs for new embeddings")]
    elif args.mode == 'doc':
        # No text cache yet. Build the text of every candidate and cache all of
        # them, because the cache must cover every candidate to pass the check
        # above on the next run. Then keep only ids that still need an embedding.
        all_texts = [skb.get_doc_info(idx, add_rel=args.add_rel, compact=args.compact)
                     for idx in tqdm(all_indices, desc="Gathering docs")]
        pd.DataFrame({'index': all_indices, 'text': all_texts}).to_csv(csv_cache, index=False)
        indices, texts = [], []
        for idx, text in zip(all_indices, all_texts):
            if idx not in emb_dict:
                indices.append(idx)
                texts.append(text)
    else:
        indices = [idx for idx in all_indices if idx not in emb_dict]
        texts = [qa_dataset.get_query_by_qid(idx) for idx in tqdm(indices, desc="Gathering docs")]

    print(f'Generating embeddings for {len(texts)} texts...')
    if args.emb_model == 'contriever':
        encoder, tokenizer = get_contriever(dataset_name=args.dataset)
        # Honour --device when given; otherwise fall back to the previous CUDA default.
        contriever_device = args.device or 'cuda'

        def encode_batch(batch_texts):
            return get_contriever_embeddings(batch_texts, encoder=encoder, tokenizer=tokenizer,
                                             device=contriever_device)
    else:
        def encode_batch(batch_texts):
            # encode_kwargs holds only the encoder options set on the command line.
            return get_embeddings(batch_texts, args.emb_model, **encode_kwargs)

    for start in tqdm(range(0, len(texts), args.batch_size), desc="Generating embeddings"):
        stop = start + args.batch_size
        batch_texts = texts[start:stop]
        # Reshape to one row per input text, then move to CPU for saving.
        batch_embs = encode_batch(batch_texts).view(len(batch_texts), -1).cpu()
        for idx, emb in zip(indices[start:stop], batch_embs):
            emb_dict[idx] = emb.view(1, -1)

    torch.save(emb_dict, emb_path)
    print(f'Saved {len(emb_dict)} embeddings to {emb_path}!')
|
|