| from Reasoning.text_retrievers.contriever import Contriever |
| from Reasoning.text_retrievers.ada import Ada |
| from stark_qa import load_qa, load_skb |
|
|
| import pickle as pkl |
| from tqdm import tqdm |
| from transformers import BertTokenizer, BertModel |
|
|
# Hugging Face checkpoint name shared by the tokenizer and the encoder below.
model_name = "bert-base-uncased"

# Both objects are created once, at import time, and are read as module-level
# globals by prepare_text_emb_symb_enc when it embeds path strings.
tokenizer = BertTokenizer.from_pretrained(model_name)
encoder = BertModel.from_pretrained(model_name)
|
|
|
|
def get_bm25_scores(dataset_name, bm25, outputs):
    """Attach BM25 scores to every record and fill the gaps in its score vectors.

    Each record in ``outputs`` is updated in place:

    * ``'bm_score_dict'`` is set to the result of ``bm25.score`` for the
      record's query, computed over the de-duplicated union of every path's
      nodes after the first one and the record's ``'ans_ids'``.
    * every ``-1`` entry of a vector in ``'bm_vector_dict'`` is replaced by
      the BM25 score of the path node at the same position.
    * for ``dataset_name == 'prime'`` only, ``'paths'`` is replaced by a dict
      whose paths are left-padded with ``-1`` or cut down to their last
      three nodes.

    Args:
        dataset_name: Dataset identifier. ``'prime'`` uses the paths as they
            are; any other value drops ``-1`` entries from each path before
            scoring and gap filling.
        bm25: Object exposing ``score(query, q_id, candidate_ids=...)`` whose
            result can be indexed by node id.
        outputs: List of dicts with the keys ``'query'``, ``'q_id'``,
            ``'ans_ids'``, ``'paths'`` and ``'bm_vector_dict'``.

    Returns:
        A new list containing the same (mutated) record objects, in order.

    Raises:
        AssertionError: If a vector that contains ``-1`` does not have the
            same length as its path.
    """
    max_len = 3  # fixed path length enforced for the 'prime' dataset
    is_prime = dataset_name == 'prime'
    new_outputs = []

    for record in outputs:
        query, q_id, ans_ids = record['query'], record['q_id'], record['ans_ids']
        paths = record['paths']

        if is_prime:
            scoring_paths = paths
        else:
            # -1 entries are dropped so that positions line up with the
            # (unpadded) score vectors.
            scoring_paths = {
                key: [node for node in path if node != -1]
                for key, path in paths.items()
            }

        # The first node of every path is deliberately left out of the
        # candidate set; only later nodes and the answers are scored.
        candidates_ids = [
            node for path in scoring_paths.values() for node in path[1:]
        ]
        candidates_ids.extend(ans_ids)
        candidates_ids = list(set(candidates_ids))

        bm_score_dict = bm25.score(query, q_id, candidate_ids=candidates_ids)
        record['bm_score_dict'] = bm_score_dict

        # Vectors are patched inside the record's own dict, so no
        # re-assignment of 'bm_vector_dict' is needed afterwards.
        bm_vector_dict = record['bm_vector_dict']
        for key, vector in bm_vector_dict.items():
            if -1 not in vector:
                continue
            path = scoring_paths[key]
            assert len(path) == len(vector), (
                f"path/vector length mismatch for key {key!r}"
            )
            bm_vector_dict[key] = [
                bm_score_dict[node] if score == -1 else score
                for node, score in zip(path, vector)
            ]

        if is_prime:
            new_paths = {}
            for key, path in paths.items():
                if len(path) < max_len:
                    path = [-1] * (max_len - len(path)) + path
                elif len(path) > max_len:
                    path = path[-max_len:]
                new_paths[key] = path
            record['paths'] = new_paths

        new_outputs.append(record)

    return new_outputs
|
|
|
|
def prepare_score_vector_dict(raw_data):
    """Build a fixed-length score vector for every prediction of every record.

    For each key of a record's ``'pred_dict'`` the matching list from
    ``'bm_vector_dict'`` is extended with the prediction score, then
    left-padded with ``0`` or cut down to its last four entries. The result
    is stored under ``'score_vector_dict'``; ``'bm_vector_dict'`` itself is
    not modified.

    Args:
        raw_data: List of dicts with the keys ``'pred_dict'`` (key -> score)
            and ``'bm_vector_dict'`` (key -> list of scores).

    Returns:
        The same ``raw_data`` list, with every record updated in place.
    """
    vector_len = 4  # every score vector is normalised to this many entries

    for record in raw_data:
        pred_dict = record['pred_dict']
        bm_vector_dict = record['bm_vector_dict']

        score_vector_dict = {}
        for key, rk_score in pred_dict.items():
            # List concatenation creates a new list, leaving the BM25 vector
            # untouched.
            score_vector = bm_vector_dict[key] + [rk_score]

            missing = vector_len - len(score_vector)
            if missing > 0:
                score_vector = [0] * missing + score_vector
            elif missing < 0:
                score_vector = score_vector[-vector_len:]

            score_vector_dict[key] = score_vector

        record['score_vector_dict'] = score_vector_dict

    return raw_data
|
|
|
|
def prepare_text_emb_symb_enc(raw_data, skb):
    """Add textual path descriptions and symbolic encodings to every record.

    For every path of every record:

    * ``'text_emb_dict'[key]`` receives the space-joined node types of the
      path (``skb.get_node_type_by_id`` per node, ``"padding"`` for ``-1``).
    * ``'symb_enc_dict'[key]`` receives a three-element code looked up by the
      number of non-``-1`` nodes in the path (1, 2 or 3).

    Every distinct path string is then embedded once with the module-level
    ``tokenizer`` / ``encoder`` by averaging ``last_hidden_state`` over
    ``dim=1``.

    Args:
        raw_data: List of dicts with the keys ``'paths'`` and ``'pred_dict'``;
            both must have the same number of entries.
        skb: Knowledge base object exposing ``get_node_type_by_id(node_id)``.

    Returns:
        ``{'data': raw_data, 'text2emb_dict': {path string: embedding}}``,
        where ``raw_data`` is the input list, updated in place.

    Raises:
        AssertionError: If a record's ``'paths'`` and ``'pred_dict'`` differ
            in size.
        KeyError: If a path has zero or more than three non-``-1`` nodes.
    """
    # Function-scope import: torch is not imported by the file's import
    # block, but is already required by the 'pt' tensors produced below.
    import torch

    symbolic_encode_dict = {
        3: [0, 1, 1],
        2: [2, 0, 1],
        1: [2, 2, 0],
    }

    # Insertion-ordered registry of distinct path strings. Dict membership
    # replaces the former parallel list, whose `in` test was linear per path.
    text2emb_dict = {}

    for record in raw_data:
        paths = record['paths']
        preds = record['pred_dict']
        assert len(paths) == len(preds), (
            "'paths' and 'pred_dict' must have the same number of entries"
        )

        text_emb_dict = {}
        symb_enc_dict = {}
        record['text_emb_dict'] = text_emb_dict
        record['symb_enc_dict'] = symb_enc_dict

        for key, path in paths.items():
            node_types = [
                skb.get_node_type_by_id(node_id) if node_id != -1 else "padding"
                for node_id in path
            ]
            text_path_str = " ".join(node_types)
            if text_path_str not in text2emb_dict:
                text2emb_dict[text_path_str] = -1  # placeholder until encoded
            text_emb_dict[key] = text_path_str

            num_real_nodes = sum(1 for node_id in path if node_id != -1)
            symb_enc_dict[key] = symbolic_encode_dict[num_real_nodes]

    # Inference only: disabling autograd avoids building a graph that the
    # subsequent detach() would discard anyway.
    with torch.no_grad():
        for text in text2emb_dict:
            input_ids = tokenizer(text, return_tensors='pt')['input_ids']
            hidden_states = encoder(input_ids).last_hidden_state
            text2emb_dict[text] = hidden_states.mean(dim=1).detach()

    return {'data': raw_data, 'text2emb_dict': text2emb_dict}
|
|
|
|
def prepare_trajectories(dataset_name, bm25, skb, outputs):
    """Run the three preparation stages over ``outputs`` and return the result.

    The stages are applied in order: ``get_bm25_scores``, then
    ``prepare_score_vector_dict``, then ``prepare_text_emb_symb_enc``.

    Args:
        dataset_name: Dataset identifier forwarded to ``get_bm25_scores``.
        bm25: Scorer forwarded to ``get_bm25_scores``.
        skb: Knowledge base forwarded to ``prepare_text_emb_symb_enc``.
        outputs: List of record dicts to prepare.

    Returns:
        The dict produced by ``prepare_text_emb_symb_enc``
        (keys ``'data'`` and ``'text2emb_dict'``).
    """
    scored = get_bm25_scores(dataset_name, bm25, outputs)
    vectorised = prepare_score_vector_dict(scored)
    return prepare_text_emb_symb_enc(vectorised, skb)
|
|
| |
def get_contriever_scores(dataset_name, mod, skb, path):
    """Add Contriever scores to the pickled records at ``path`` and rewrite it.

    The pickle at ``path`` must hold a dict whose ``'data'`` entry is a list
    of records ordered like the ``mod`` split of the QA dataset. Each record
    receives ``'contriever_score_dict'``: the result of ``Contriever.score``
    over the union of its predicted ids (keys of ``'pred_dict'``) and the
    ground-truth answer ids. The updated dict is written back to ``path``.

    Args:
        dataset_name: Dataset name passed to ``load_qa`` and ``Contriever``.
        mod: Key selecting the split from ``qa.get_idx_split(...)``.
        skb: Knowledge base passed to ``Contriever``.
        path: Pickle file that is read and then overwritten.

    Raises:
        AssertionError: If a stored query differs from the dataset query at
            the same position.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at files produced by this pipeline.
    with open(path, 'rb') as f:
        data = pkl.load(f)

    raw_data = data['data']

    qa = load_qa(dataset_name, human_generated_eval=False)
    contriever = Contriever(skb, dataset_name, device='cuda')

    split_idx = qa.get_idx_split(test_ratio=1.0)
    all_indices = split_idx[mod].tolist()

    for idx, i in enumerate(tqdm(all_indices)):
        query, q_id, ans_ids, _ = qa[i]
        record = raw_data[idx]
        assert query == record['query'], (
            f"query mismatch at position {idx} (dataset index {i})"
        )

        # De-duplicate the union, not just the predictions: previously an
        # answer id that was also predicted was passed to the scorer twice.
        # Assumes the scorer does not depend on candidate order or
        # multiplicity - TODO confirm.
        candidates_ids = list(set(record['pred_dict']) | set(ans_ids))

        record['contriever_score_dict'] = contriever.score(
            query, q_id, candidate_ids=candidates_ids
        )

    data['data'] = raw_data

    with open(path, 'wb') as f:
        pkl.dump(data, f)
| |
def get_ada_scores(dataset_name, mod, skb, path):
    """Add Ada scores to the pickled records at ``path`` and rewrite the file.

    The pickle at ``path`` must hold a dict whose ``'data'`` entry is a list
    of records ordered like the ``mod`` split of the QA dataset. Each record
    receives ``'ada_score_dict'``: the result of ``Ada.score`` over the union
    of its predicted ids (keys of ``'pred_dict'``) and the ground-truth
    answer ids. The updated dict is written back to ``path``.

    Args:
        dataset_name: Dataset name passed to ``load_qa`` and ``Ada``.
        mod: Key selecting the split from ``qa.get_idx_split(...)``.
        skb: Knowledge base passed to ``Ada``.
        path: Pickle file that is read and then overwritten.

    Raises:
        AssertionError: If a stored query differs from the dataset query at
            the same position.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at files produced by this pipeline.
    with open(path, 'rb') as f:
        data = pkl.load(f)

    raw_data = data['data']

    qa = load_qa(dataset_name, human_generated_eval=False)
    ada = Ada(skb, dataset_name, device='cuda')

    split_idx = qa.get_idx_split(test_ratio=1.0)
    all_indices = split_idx[mod].tolist()

    for idx, i in enumerate(tqdm(all_indices)):
        query, q_id, ans_ids, _ = qa[i]
        record = raw_data[idx]
        assert query == record['query'], (
            f"query mismatch at position {idx} (dataset index {i})"
        )

        # De-duplicate the union, not just the predictions: previously an
        # answer id that was also predicted was passed to the scorer twice.
        # Assumes the scorer does not depend on candidate order or
        # multiplicity - TODO confirm.
        candidates_ids = list(set(record['pred_dict']) | set(ans_ids))

        record['ada_score_dict'] = ada.score(
            query, q_id, candidate_ids=candidates_ids
        )

    data['data'] = raw_data

    with open(path, 'wb') as f:
        pkl.dump(data, f)
|
|
if __name__ == '__main__':
    # Running the module directly only prints a marker line; no preparation
    # function is invoked here.
    print("Test prepare_rerank")
| |
| |