import logging
from typing import Dict, Optional

import numpy as np
import torch
from torch.nn import functional as F

from tevatron.modeling import EncoderDecoderModelForSeq2SeqLM
from tevatron.tree import dec_2d

logger = logging.getLogger(__name__)


class GLENP2Model(EncoderDecoderModelForSeq2SeqLM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_device(self):
        """Get the device of the model parameters."""
        return next(self.lm_p.parameters()).device

    def encode_passage(self, psg: Optional[Dict]):
        """Encode a passage batch into multi-vector representations."""
        if psg is None:
            return None, None, None

        past_key_values, encoder_outputs = None, None
        decoder_inputs_embeds = self.lm_p.get_input_embeddings()(
            psg["decoder_input_ids"]
        )
        decoder_attention_mask_full = torch.ones(
            psg["input_ids"].shape[0],
            self.model_args.num_multi_vectors,
            dtype=torch.long,
            device=decoder_inputs_embeds.device,
        )

        p_reps = []
        for i in range(self.model_args.num_multi_vectors):
            decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
            psg_out = self.lm_p(
                input_ids=psg["input_ids"],
                attention_mask=psg["attention_mask"],
                decoder_inputs_embeds=decoder_inputs_embeds,
                decoder_attention_mask=decoder_attention_mask,
                return_dict=True,
                encoder_outputs=encoder_outputs,
                output_hidden_states=True,
                use_cache=True,
                past_key_values=past_key_values,
            )
            if encoder_outputs is None:
                encoder_outputs = psg_out.encoder_hidden_states
            past_key_values = psg_out.past_key_values
            # Feed the last decoder hidden state back as the next decoder
            # input; the encoder outputs and the KV cache are reused across
            # steps, so each iteration only decodes one new position.
            decoder_inputs_embeds = psg_out.decoder_hidden_states[-1][:, -1:, :]
            p_reps.append(psg_out.decoder_hidden_states[-1][:, -1:, :])
        p_reps = torch.cat(p_reps, dim=1)

        # Scale by 1/sqrt(d), as in dot-product attention.
        p_reps = p_reps * (p_reps.size(-1) ** -0.5)
        p_reps_dt = p_reps.clone()

        # Project onto the shared embedding table to get vocabulary logits.
        lm_logits = p_reps @ self.lm_p.shared.weight.T

        if self.model_args.mask_special_tokens_for_decoding:
            # Exclude special tokens (except BOS/EOS/PAD) from the doc-ID
            # distribution.
            special_token_ids = self.tokenizer.all_special_ids
            special_token_ids = [
                x
                for x in special_token_ids
                if x
                not in [
                    self.tokenizer.bos_token_id,
                    self.tokenizer.eos_token_id,
                    self.tokenizer.pad_token_id,
                ]
            ]
            lm_logits[:, :, special_token_ids] = -float("inf")

        if self.model_args.do_docid_temperature_annealing:
            # Exponentially decay the temperature with the epoch, clipped at a
            # configured minimum.
            first_temperature = self.model_args.docid_temperature
            cur_epoch = (
                self.trainer.state.epoch if hasattr(self, "trainer") else self.cur_epoch
            )
            temperature = max(
                self.model_args.docid_temperature_min,
                first_temperature * np.exp(-cur_epoch),
            )
        else:
            temperature = self.model_args.docid_temperature
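        # Annealing example (illustrative values): with docid_temperature = 1.0
        # and docid_temperature_min = 0.1, epochs 0, 1, 2, 3 give temperatures
        # 1.0, 0.368, 0.135, and 0.1 (the last clipped at the minimum).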

        # Soft re-embedding: a temperature-scaled softmax over the vocabulary,
        # followed by a convex combination of the shared embedding rows.
        lm_logits = lm_logits / temperature
        lm_attention = torch.softmax(lm_logits, dim=-1)
        p_reps = lm_attention @ self.lm_p.shared.weight

        return p_reps, lm_attention, p_reps_dt

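    # Shape sketch for encode_passage (illustrative; B = batch size,
    # M = num_multi_vectors, D = hidden size, V = vocab size):
    #   decoder loop           -> p_reps:       (B, M, D), kept as p_reps_dt
    #   p_reps @ E.T           -> lm_logits:    (B, M, V), E = shared embedding (V, D)
    #   softmax(lm_logits / t) -> lm_attention: (B, M, V)
    #   lm_attention @ E       -> p_reps:       (B, M, D), the soft re-embedding
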
    def encode_query(self, qry: Optional[Dict]):
        """Encode a query batch into multi-vector representations.

        Mirrors encode_passage, but returns the raw (scaled) decoder states
        without the soft re-embedding step.
        """
        if qry is None:
            return None

        past_key_values, encoder_outputs = None, None
        decoder_inputs_embeds = self.lm_q.get_input_embeddings()(
            qry["decoder_input_ids"]
        )
        decoder_attention_mask_full = torch.ones(
            qry["input_ids"].shape[0],
            self.model_args.num_multi_vectors,
            dtype=torch.long,
            device=decoder_inputs_embeds.device,
        )

        q_reps = []
        for i in range(self.model_args.num_multi_vectors):
            decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
            qry_out = self.lm_q(
                input_ids=qry["input_ids"],
                attention_mask=qry["attention_mask"],
                decoder_inputs_embeds=decoder_inputs_embeds,
                decoder_attention_mask=decoder_attention_mask,
                return_dict=True,
                encoder_outputs=encoder_outputs,
                output_hidden_states=True,
                use_cache=True,
                past_key_values=past_key_values,
            )
            if encoder_outputs is None:
                encoder_outputs = qry_out.encoder_hidden_states
            past_key_values = qry_out.past_key_values
            decoder_inputs_embeds = qry_out.decoder_hidden_states[-1][:, -1:, :]
            q_reps.append(qry_out.decoder_hidden_states[-1][:, -1:, :])
        q_reps = torch.cat(q_reps, dim=1)
        q_reps = q_reps * (q_reps.size(-1) ** -0.5)
        return q_reps

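    # Expected feature dicts (illustrative): psg and qry each carry
    #   "input_ids" / "attention_mask": LongTensor (B, L) encoder inputs, and
    #   "decoder_input_ids": LongTensor (B, 1), the start token embedded for
    #   the first decoded vector; later steps are fed hidden states instead.
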
    def compute_similarity(self, q_reps: torch.Tensor, p_reps: torch.Tensor):
        """Compute similarity between queries and passages."""
        q_reps = q_reps.permute(1, 0, 2)
        p_reps = p_reps.permute(1, 0, 2)
        scores = torch.bmm(q_reps, p_reps.permute(0, 2, 1)).permute(1, 0, 2)
        scores = scores.sum(dim=1) / self.model_args.num_multi_vectors
        return scores

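    # Worked example (shapes only): with B_q = 2 queries, B_p = 8 passages,
    # M = num_multi_vectors = 4, and D = 768:
    #   q_reps (2, 4, 768) -> permute -> (4, 2, 768)
    #   p_reps (8, 4, 768) -> permute -> (4, 8, 768), transposed to (4, 768, 8)
    #   bmm -> (4, 2, 8) -> permute -> (2, 4, 8)
    #   sum over M / M -> scores (2, 8), i.e. the mean over vector slots.
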
    def make_doc_id(self, batch: Dict):
        """Greedily decode a doc ID for each document in the batch."""
        self.eval()
        with torch.no_grad():
            past_key_values, encoder_outputs = None, None
            device = self.get_device()
            # Start decoding from token id 0 (the decoder start token).
            decoder_inputs_embeds = self.lm_p.get_input_embeddings()(
                torch.tensor([0], dtype=torch.long, device=device)
            )
            decoder_inputs_embeds = decoder_inputs_embeds.unsqueeze(0).repeat(
                batch["source_ids"].shape[0], 1, 1
            )
            decoder_attention_mask_full = torch.ones(
                batch["source_ids"].shape[0],
                self.model_args.max_output_length - 1,
                dtype=torch.long,
                device=device,
            )
            outs, out_logits = [], []
            for i in range(self.model_args.max_output_length - 1):
                decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
                psg_out = self.lm_p(
                    input_ids=batch["source_ids"].to(device),
                    attention_mask=batch["source_mask"].to(device),
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    decoder_attention_mask=decoder_attention_mask,
                    return_dict=True,
                    encoder_outputs=encoder_outputs,
                    output_hidden_states=True,
                    use_cache=True,
                    past_key_values=past_key_values,
                )
                if encoder_outputs is None:
                    encoder_outputs = psg_out.encoder_hidden_states
                past_key_values = psg_out.past_key_values
                decoder_inputs_embeds = psg_out.decoder_hidden_states[-1][:, -1:, :]
                if self.model_args.mask_special_tokens_for_decoding:
                    psg_out.logits[:, :, self.model_args.special_token_ids] = -float(
                        "inf"
                    )
                # Greedy step: take the argmax token and record its logit.
                out = psg_out.logits[:, -1, :].argmax(dim=-1)
                out_logit = (
                    psg_out.logits[:, -1, :].gather(1, out.unsqueeze(-1)).squeeze(-1)
                )
                outs.append(out.cpu().numpy())
                out_logits.append(out_logit.cpu().detach().numpy())
            outs = np.stack(outs, axis=1)
            out_logits = np.stack(out_logits, axis=1)

            # Join the decoded tokens into a "<->"-separated doc ID string.
            preds = []
            for ids in outs:
                preds.append(
                    "<->".join(
                        self.tokenizer.convert_ids_to_tokens(
                            ids, skip_special_tokens=False
                        )
                    )
                )

            texts = [
                self.tokenizer.decode(ids, skip_special_tokens=True)
                for ids in batch["source_ids"]
            ]
            oldids = list(batch["oldids"])
            out_logits = np.round(out_logits.astype(np.float64), 4)

        return oldids, texts, preds, out_logits

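    # Usage sketch (hypothetical batch; field names follow this file):
    #   batch = {"source_ids": ..., "source_mask": ..., "oldids": [...]}
    #   oldids, texts, preds, out_logits = model.make_doc_id(batch)
    # preds[i] is a "<->"-joined token string; out_logits[i] holds the greedy
    # per-step logits of the chosen tokens.
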
    def evaluation_step(self, batch: Dict, predefined_id: bool = False):
        """
        Evaluate the model on a batch.

        Used for validation during training (predefined_id=True) and for
        inference after training (predefined_id=False). At inference time,
        the per-step logits of the decoded tokens are used to break ties
        between documents that share the same predicted doc ID.

        Args:
            batch (Dict):
                batch data (source_ids, source_mask, aug_source_ids,
                aug_source_mask, target_ids, target_mask, rank, oldid)
            predefined_id (bool):
                whether ground-truth doc IDs are available (True: validation
                during training, False: inference after training)

        Returns:
            texts (List[str]):
                list of source texts
            preds (List[List[str]]):
                num_return_sequences predicted doc IDs per source text
            target_docids (List[str]):
                list of target doc IDs
            ranks (List[str]):
                list of ranks
        """
        max_output_length = (
            self.model_args.max_output_length
            if predefined_id
            else self.model_args.max_output_length - 1
        )

        num_return_sequences = self.model_args.num_return_sequences

        with torch.no_grad():
            past_key_values, encoder_outputs = None, None
            # Over-generate beams: after tie-breaking below, the candidate
            # list is truncated back to num_return_sequences.
            K = num_return_sequences * 2
            device = self.get_device()
            decoder_inputs_embeds = self.lm_p.get_input_embeddings()(
                torch.tensor([0], dtype=torch.long, device=device)
            )
            decoder_inputs_embeds = decoder_inputs_embeds.unsqueeze(0).repeat(
                batch["source_ids"].shape[0], 1, 1
            )
            decoder_attention_mask_full = torch.ones(
                batch["source_ids"].shape[0],
                self.model_args.max_output_length - 1,
                dtype=torch.long,
                device=device,
            )
            decode_tree = self.root if self.model_args.tree else None

            outs, out_logits, all_next_token_logits = [], [], []
            for i in range(max_output_length):
                decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
                psg_out = self.lm_p(
                    input_ids=batch["source_ids"].to(device),
                    attention_mask=batch["source_mask"].to(device),
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    decoder_attention_mask=decoder_attention_mask,
                    return_dict=True,
                    encoder_outputs=encoder_outputs,
                    output_hidden_states=True,
                    use_cache=True,
                    past_key_values=past_key_values,
                )
                if encoder_outputs is None:
                    encoder_outputs = psg_out.encoder_hidden_states
                past_key_values = psg_out.past_key_values
                decoder_inputs_embeds = psg_out.decoder_hidden_states[-1][:, -1:, :]

                if (
                    not predefined_id
                    and self.model_args.mask_special_tokens_for_decoding
                ):
                    psg_out.logits[
                        :, :, self.model_args.special_token_ids
                    ] = -float("inf")

                next_token_logits = psg_out.logits[:, -1, :]
                all_next_token_logits.append(
                    next_token_logits.detach().cpu().unsqueeze(1)
                )

                batch_size, vocab_size = next_token_logits.shape
                scores = F.log_softmax(next_token_logits, dim=-1)
                if i == 0:
                    if decode_tree is not None:
                        # Constrain the first step to the trie root's children.
                        mask = torch.full_like(scores, float("-inf"))
                        candidates = list(decode_tree.children.keys())
                        mask[:, candidates] = 0
                        scores += mask
                    out_prob, out_index = scores.topk(K, dim=-1)
                    cur_prob, cur_index = out_prob, out_index
                    all_index = cur_index.unsqueeze(-1)
                else:
                    out_prob = scores
                    # The decoder trajectory is token-independent (it is fed
                    # hidden states, not beam tokens), so the same next-token
                    # distribution is combined with every beam's running score.
                    all_prob = cur_prob[:, :, None] + out_prob.unsqueeze(1)
                    if decode_tree is not None:
                        mask = torch.full_like(all_prob, float("-inf"))
                        for b in range(batch_size):
                            for k in range(K):
                                previous_index = all_index[b, k, :].tolist()
                                cur = decode_tree
                                for value in previous_index:
                                    if value not in cur.children:
                                        # Prefix left the trie: fall back to
                                        # allowing only token id 1 (EOS for
                                        # T5-style vocabularies).
                                        next_candidates = [1]
                                        break
                                    cur = cur.children[value]
                                else:
                                    next_candidates = list(cur.children.keys())
                                mask[b, k, next_candidates] = 0
                        all_prob += mask
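                    # Trie-mask example (illustrative): if a beam's prefix
                    # [2, 15] reaches a node whose children are {7, 42}, only
                    # token ids 7 and 42 keep finite scores at this step; all
                    # other continuations are pushed to -inf.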

                    top_prob, top_index = all_prob.view(batch_size, -1).topk(
                        K, dim=-1
                    )
                    # Recover the parent beam and the new token from the
                    # flattened (K * vocab_size) index.
                    before_index = torch.div(
                        top_index, vocab_size, rounding_mode="floor"
                    )
                    after_index = top_index % vocab_size

                    before_index = all_index.gather(
                        1, before_index.unsqueeze(-1).repeat(1, 1, all_index.shape[-1])
                    )
                    all_index = torch.cat(
                        [before_index, after_index.unsqueeze(-1)], dim=-1
                    )
                    cur_prob = top_prob

            all_index = all_index[:, :num_return_sequences, :]
            outs = all_index.reshape(-1, max_output_length)
            outs = outs.detach().cpu()
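            # At this point (illustrative): outs has shape
            # (batch_size * num_return_sequences, max_output_length), where row
            # b * num_return_sequences + k holds the k-th best sequence for
            # source b.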

            if predefined_id:
                dec = []
                for out in outs:
                    # Truncate at the first EOS token, if any.
                    if self.tokenizer.eos_token_id in out:
                        out = out[: list(out).index(self.tokenizer.eos_token_id)]
                    dec.append(
                        "-".join(
                            self.tokenizer.convert_ids_to_tokens(
                                out, skip_special_tokens=True
                            )
                        )
                    )
            else:
                # Gather, for each returned sequence, the logit of the chosen
                # token at every decoding step (used for tie-breaking below).
                out_logits = torch.zeros_like(outs, dtype=torch.float32)
                all_next_token_logits = (
                    torch.cat(all_next_token_logits, dim=1).detach().cpu()
                )
                for b in range(batch_size):
                    outs_b = outs[
                        b * num_return_sequences : (b + 1) * num_return_sequences
                    ]
                    all_next_token_logits_b = all_next_token_logits[b : b + 1].repeat(
                        num_return_sequences, 1, 1
                    )
                    out_logits_b = all_next_token_logits_b.gather(
                        2, outs_b.unsqueeze(-1)
                    ).squeeze(-1)
                    out_logits[
                        b * num_return_sequences : (b + 1) * num_return_sequences
                    ] = out_logits_b

                dec, temp_dec = [], []
                for out_cnt, ids in enumerate(outs):
                    pred_id = "<->".join(
                        self.tokenizer.convert_ids_to_tokens(
                            ids, skip_special_tokens=False
                        )
                    )
                    out_logits_i = out_logits[out_cnt]

                    if self.docid2num_docs[pred_id] == 1:
                        # Unique doc ID: exactly one document, no tie to break.
                        oldid = self.docid2oldids[pred_id][0]
                        temp_dec.append(oldid + "<->" + pred_id)
                    else:
                        # Several documents share this doc ID: order them with
                        # the configured reranking strategy.
                        if self.model_args.reranking == "random":
                            cand_oldids = self.docid2oldids[pred_id]
                            cand_oldids = np.random.permutation(cand_oldids)
                            pred_id_list = [
                                oldid + "<->" + pred_id for oldid in cand_oldids
                            ]
                            temp_dec += pred_id_list
                        elif self.model_args.reranking in ["cosine", "mse"]:
                            cand_oldids = self.docid2oldids[pred_id]
                            rank_scores = []
                            for cand_oldid in cand_oldids:
                                docid_logit = self.oldid2docid_logit[cand_oldid]
                                if self.model_args.reranking == "cosine":
                                    rank_score = torch.cosine_similarity(
                                        docid_logit.unsqueeze(0),
                                        out_logits_i.unsqueeze(0),
                                    ).item()
                                else:  # "mse"
                                    # Negate the error so the descending sort
                                    # below ranks the closest candidate first.
                                    rank_score = -(
                                        (docid_logit - out_logits_i)
                                        .pow(2)
                                        .mean()
                                        .item()
                                    )
                                rank_scores.append(rank_score)
                            rank_scores = np.array(rank_scores)
                            order_index = np.argsort(-rank_scores)
                            ordered_cand_oldids = np.array(cand_oldids)[order_index]
                            pred_id_list = [
                                oldid + "<->" + pred_id
                                for oldid in ordered_cand_oldids
                            ]
                            temp_dec += pred_id_list
                        else:
                            raise NotImplementedError(
                                "reranking method should be one of "
                                "['random', 'cosine', 'mse']"
                            )

                    # After the last beam of each source, keep the top
                    # num_return_sequences candidates and flush.
                    if out_cnt % num_return_sequences == num_return_sequences - 1:
                        temp_dec = temp_dec[:num_return_sequences]
                        dec += temp_dec
                        temp_dec = []

                # Inference: derive the target strings from the oldid -> docid
                # mapping (validation batches are assumed to carry them in the
                # rank field already).
                targets = []
                for oldid in batch["oldid"]:
                    targets.append(oldid + "<->" + self.oldid2docid[oldid])
                batch["rank"][0][0] = targets

            texts = [
                self.tokenizer.decode(ids, skip_special_tokens=True)
                for ids in batch["source_ids"]
            ]
            preds = dec_2d(dec, num_return_sequences)
            target_docids = batch["rank"][0][0]
            ranks = [str(a.item()) for a in batch["rank"][0][1]]

            return texts, preds, target_docids, ranks

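    # Evaluation sketch (hypothetical call sites): during validation,
    #   texts, preds, target_docids, ranks = model.evaluation_step(
    #       batch, predefined_id=True
    #   )
    # and at inference predefined_id=False, where ties among documents sharing
    # a predicted doc ID are broken by model_args.reranking
    # ('random', 'cosine', or 'mse').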