| import collections |
|
|
| from torch.serialization import default_restore_location |
|
|
| from transformers import BertTokenizer, BertModel |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from tqdm import tqdm |
| import numpy as np |
| import pickle |
| import argparse |
| import csv |
|
|
# Module-level scratch mapping.
# NOTE(review): nothing in this script reads or writes it. Confirm no
# external importer relies on it before removing.
nq_temp = dict()
|
|
# Field layout of a saved training checkpoint.
# load_states_from_checkpoint builds one with CheckpointState(**state_dict),
# so the saved dict must contain exactly these six keys.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    (
        "model_dict",
        "optimizer_dict",
        "scheduler_dict",
        "offset",
        "epoch",
        "encoder_params",
    ),
)
|
|
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Load a checkpoint file onto the CPU and wrap it in a CheckpointState.

    Args:
        model_file: Path to a file previously written with ``torch.save``
            whose top-level object is a dict keyed by CheckpointState's fields.

    Returns:
        A CheckpointState built from the loaded dict.

    Raises:
        TypeError: if the loaded dict's keys do not exactly match
            CheckpointState's fields.

    NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    def _restore_on_cpu(storage, location):
        # Ignore the device each tensor was saved from; place it on the CPU.
        return default_restore_location(storage, 'cpu')

    raw_state = torch.load(model_file, map_location=_restore_on_cpu)
    return CheckpointState(**raw_state)
|
|
class DocPool(Dataset):
    """Map-style dataset over the text column of a two-column TSV file.

    Each row of the input file must have exactly two tab-separated fields:
    an id and a text. Only the text is kept. Items are addressed by their
    0-based row position, not by the id stored in the first column.
    """

    def __init__(self, path):
        """Read every row of ``path`` into memory.

        Args:
            path: Path to a UTF-8 encoded, tab-separated file. Fields are
                parsed with the csv module's default dialect rules (e.g.
                double-quote quoting), apart from the tab delimiter.

        Raises:
            ValueError: if a row (including a blank line) does not have
                exactly two fields.
        """
        docs = []
        # newline='' lets the csv module handle line endings itself. Line
        # breaks inside quoted fields (including '\r\n') are then kept
        # verbatim instead of being rewritten by universal-newline translation.
        with open(path, "r", encoding="utf8", newline="") as f:
            for _id, _text in csv.reader(f, delimiter='\t'):
                # The id column is discarded; __getitem__ returns the row
                # position instead.
                docs.append(_text)
        self.doc = docs

    def __len__(self):
        """Return the number of rows read from the file."""
        return len(self.doc)

    def __getitem__(self, index):
        """Return ``(index, text)`` for the row at position ``index``."""
        return index, self.doc[index]
|
|
|
|
def my_collate(batch):
    """Transpose a list of ``(index, text)`` items into parallel tuples.

    Args:
        batch: Sequence of items as produced by ``DocPool.__getitem__``.

    Returns:
        ``{'id': (index, ...), 'doc': (text, ...)}`` in batch order.
    """
    columns = tuple(zip(*batch))
    ids, docs = columns[0], columns[1]
    return {'id': ids, 'doc': docs}
|
|
|
|
def extract_feature(args):
    """Encode every row of the doc or query TSV and pickle the embeddings.

    Args:
        args: Parsed command-line namespace. Reads:
            - ``doc_or_query``: ``'doc'`` selects ``./nq_doc.tsv`` and the
              checkpoint's ``ctx_model.`` weights; any other value selects
              ``./nq_query.tsv`` and the ``question_model.`` weights.
            - ``batch_size``: rows per forward pass.
            - ``model_name``: identifier passed to
              ``BertTokenizer/BertModel.from_pretrained``.
            - ``model_file``: checkpoint path for
              ``load_states_from_checkpoint``.

    Side effects:
        Writes ``./doc_embedding.pickle`` or ``./query_embedding.pickle``.
        The file holds a dict mapping each id yielded by the dataloader
        (the row position returned by DocPool) to that row's first-token
        hidden state, as a float64 numpy vector of length
        ``model.config.hidden_size``.
    """
    # Fixed seeds; inference is otherwise deterministic apart from GPU kernels.
    torch.manual_seed(2024)
    torch.cuda.manual_seed(2024)
    np.random.seed(2024)
    if args.doc_or_query == 'doc':
        _path = './nq_doc.tsv'
        _out_path = './doc_embedding.pickle'
        _prefix = 'ctx_model.'
    else:
        _path = './nq_query.tsv'
        _out_path = './query_embedding.pickle'
        _prefix = 'question_model.'
    with torch.no_grad():
        doc_dataset = DocPool(_path)
        print(len(doc_dataset))
        doc_dataloader = DataLoader(dataset=doc_dataset, batch_size=args.batch_size,
                                    shuffle=False, collate_fn=my_collate)
        tokenizer = BertTokenizer.from_pretrained(args.model_name)
        model = BertModel.from_pretrained(args.model_name, return_dict=True)
        # Read the embedding width from the config rather than hard-coding
        # 768, so encoders of other sizes work too.
        hidden_size = model.config.hidden_size

        # Keep only the sub-encoder's weights and strip their prefix so the
        # keys line up with a plain BertModel state dict.
        saved_state = load_states_from_checkpoint(args.model_file)
        prefix_len = len(_prefix)
        ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items()
                     if key.startswith(_prefix)}
        if not ctx_state:
            print(f"WARNING: no parameters with prefix '{_prefix}' found in {args.model_file}; "
                  f"encoding with the un-fine-tuned '{args.model_name}' weights.")
        # strict=False tolerates key differences between the checkpoint and
        # this BertModel. Report them instead of silently ignoring them.
        incompatible = model.load_state_dict(ctx_state, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f"WARNING: checkpoint/model key mismatch for prefix '{_prefix}': "
                  f"{len(incompatible.missing_keys)} missing (e.g. {incompatible.missing_keys[:3]}), "
                  f"{len(incompatible.unexpected_keys)} unexpected (e.g. {incompatible.unexpected_keys[:3]})")

        model = torch.nn.DataParallel(model)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        model = model.to(device)
        model.eval()

        ids = []
        idx = 0
        doc_feature = np.zeros((len(doc_dataset), hidden_size))
        for batch_data in tqdm(doc_dataloader):
            doc_id = batch_data['id']
            doc_body = batch_data['doc']
            inputs = tokenizer(doc_body, padding=True, truncation=True, return_tensors="pt",
                               add_special_tokens=True).to(device)
            outputs = model(**inputs)
            # Hidden state of the first token ([CLS] for a BERT tokenizer with
            # special tokens). Note this is not BERT's tanh pooler_output.
            cls_embedding = outputs.last_hidden_state[:, 0]
            batch_len = cls_embedding.shape[0]

            ids.extend(doc_id)
            doc_feature[idx: idx + batch_len] = cls_embedding.cpu().numpy()
            idx += batch_len

    feature_dic = {id_i: doc_feature[i] for i, id_i in enumerate(ids)}
    with open(_out_path, 'wb') as f:
        pickle.dump(feature_dic, f)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch_size', type=int, default=512,
                        help='minibatch size')
    parser.add_argument('--model_name', type=str, default='Luyu/co-condenser-wiki',
                        help='model name')
    # The help text was a copy-paste of --model_name's; this is a file path.
    parser.add_argument('--model_file', type=str, default='./dpr_biencoder.38.602',
                        help='path to the DPR biencoder checkpoint')
    # extract_feature treats every value other than 'doc' as 'query'.
    # Restrict the choices so a typo fails loudly instead of silently
    # encoding the query file.
    parser.add_argument('--doc_or_query', type=str, default='query', choices=['doc', 'query'],
                        help='transfer documents or queries')

    args = parser.parse_args()
    extract_feature(args)