| import os |
| import time |
| import copy |
| import logging |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| import pandas as pd |
|
|
| from collections import defaultdict |
| from dataclasses import dataclass |
| from torch import nn, Tensor |
| from transformers import PreTrainedModel, AutoModelForSeq2SeqLM, AutoTokenizer |
| from transformers.file_utils import ModelOutput |
| from typing import Dict, Optional |
|
|
| from tevatron.tree import TreeBuilder |
| from tevatron.arguments import ( |
| GLENP2ModelArguments as ModelArguments, |
| GLENP2TrainingArguments as TrainingArguments, |
| ) |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class EncoderDecoderOutputForSeq2SeqLM(ModelOutput):
    """Result bundle produced by the encoder-decoder retrieval model.

    Every field is optional and defaults to ``None``; the producer fills in
    only what it computed.
    """

    # Query-side representations.
    q_reps: Optional[Tensor] = None
    # Passage-side representations.
    p_reps: Optional[Tensor] = None
    # Training loss (left as ``None`` when no loss is computed).
    loss: Optional[Tensor] = None
    # Query/passage similarity scores.
    scores: Optional[Tensor] = None
|
|
|
|
| class EncoderDecoderModelForSeq2SeqLM(nn.Module): |
| TRANSFORMER_CLS = AutoModelForSeq2SeqLM |
|
|
def __init__(
    self,
    lm_q: PreTrainedModel,
    lm_p: PreTrainedModel,
    untie_encoder: bool = False,
    negatives_x_device: bool = False,
    tokenizer=None,
):
    """Wire up the query/passage language models and the training state.

    Args:
        lm_q: Language model used on the query side.
        lm_p: Language model used on the passage side (may be the very
            same object as ``lm_q``).
        untie_encoder: Stored on the instance; ``save`` reads it to decide
            between one shared checkpoint and two separate ones.
        negatives_x_device: When true, representations are all-gathered
            across processes, which requires ``torch.distributed`` to be
            initialized already.
        tokenizer: Optional tokenizer kept on the instance for later use.

    Raises:
        ValueError: ``negatives_x_device`` is set but the default process
            group has not been initialized.
    """
    super().__init__()

    # Sub-modules are registered in the order lm_q, lm_p, cross_entropy.
    self.lm_q = lm_q
    self.lm_p = lm_p
    self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")

    self.negatives_x_device = negatives_x_device
    self.untie_encoder = untie_encoder
    self.tokenizer = tokenizer
    # Per-query exact-match flags accumulated by forward() during training.
    self.gen_r1_list = []

    if not self.negatives_x_device:
        return

    # Cross-device negatives: rank and world size are only defined here.
    if not dist.is_initialized():
        raise ValueError(
            "Distributed training has not been initialized for representation all gather."
        )
    self.process_rank = dist.get_rank()
    self.world_size = dist.get_world_size()
|
|
def forward(
    self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None
):
    """Encode a query and/or passage batch; in training mode also compute the loss.

    Both input dicts are modified in place: a one-token
    ``decoder_input_ids`` (filled with ``config.decoder_start_token_id``)
    and a matching ``decoder_attention_mask`` are added before encoding.

    Args:
        query: Tokenized query batch containing at least ``input_ids``,
            or ``None``.
        passage: Tokenized passage batch containing at least ``input_ids``
            (and ``oldid`` when ``negative_passage_type == "self"``),
            or ``None``.

    Returns:
        EncoderDecoderOutputForSeq2SeqLM. When either side is missing only
        the representations are returned. In eval mode ``scores`` is set and
        ``loss`` is ``None``. In training mode ``loss`` is the contrastive
        loss plus the optional query-to-docid and cosine terms.
    """
    # Give every present batch a single decoder start token.
    start_token_id = self.config.decoder_start_token_id
    for batch in (query, passage):
        if batch is None:
            continue
        input_ids = batch["input_ids"]
        n_rows = input_ids.shape[0]
        batch["decoder_input_ids"] = torch.full(
            (n_rows, 1),
            start_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        batch["decoder_attention_mask"] = torch.ones(
            n_rows, 1, dtype=torch.long, device=input_ids.device
        )

    q_reps = self.encode_query(query)

    # Bind all three passage outputs up front so the query-only path never
    # touches an unbound name.
    p_reps = p_attention = p_reps_dt = None
    if passage is not None:
        p_reps, p_attention, p_reps_dt = self.encode_passage(passage)

    # Record the currently predicted docid (and each of its prefixes) for
    # every passage. Only done for "self" negatives, so other modes skip the
    # device-to-host copy and the dataset lookups.
    if (
        p_reps is not None
        and self.trainer.data_args.negative_passage_type == "self"
    ):
        predid2oldid = self.trainer.train_dataset.predid2oldid
        oldid2predid = self.trainer.train_dataset.oldid2predid
        pred_docids = p_attention.argmax(dim=-1).cpu().tolist()
        for oldid, predid in zip(passage["oldid"].cpu().tolist(), pred_docids):
            prefix = []
            for token_id in predid:
                prefix.append(str(token_id))
                cur_predid = "-".join(prefix)
                predid2oldid[cur_predid].add(oldid)
                # After the last token this holds the full docid.
                oldid2predid[oldid] = cur_predid

    # Encoding-only call: nothing to score.
    if q_reps is None or p_reps is None:
        return EncoderDecoderOutputForSeq2SeqLM(q_reps=q_reps, p_reps=p_reps_dt)

    if not self.training:
        scores = self.compute_similarity(q_reps, p_reps)
        return EncoderDecoderOutputForSeq2SeqLM(
            loss=None, scores=scores, q_reps=q_reps, p_reps=p_reps_dt
        )

    # ---- training ----
    # Keep the per-process query representations: p_attention and p_reps_dt
    # are never gathered, so the LM-style losses below must be indexed with
    # local sizes even when the contrastive loss uses gathered tensors.
    local_q_reps = q_reps
    # Passages are laid out as consecutive groups per query; the first one
    # of each group is the positive. The ratio is the same before and after
    # gathering.
    psg_per_query = p_reps.size(0) // q_reps.size(0)

    if self.negatives_x_device:
        q_reps = self._dist_gather_tensor(q_reps)
        p_reps = self._dist_gather_tensor(p_reps)

    scores = self.compute_similarity(q_reps, p_reps)
    scores = scores.view(q_reps.size(0), -1)
    scores = scores / self.softmax_temperature

    target = (
        torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        * psg_per_query
    )
    loss = self.compute_loss(scores, target)

    model_args = self.model_args
    # Optional loss weights are looked up in the instance dict, exactly as
    # the original `in model_args.__dict__` checks did.
    arg_values = vars(model_args)

    if "infonce_loss" in arg_values:
        loss = loss * model_args.infonce_loss
    if self.negatives_x_device:
        loss = loss * self.world_size

    # Project the (local) query representation onto the vocabulary.
    lm_logits = local_q_reps @ self.lm_q.shared.weight.T

    if model_args.mask_special_tokens_for_decoding:
        # Mask every special token except bos/eos/pad.
        keep_ids = {
            self.tokenizer.bos_token_id,
            self.tokenizer.eos_token_id,
            self.tokenizer.pad_token_id,
        }
        special_token_ids = [
            x for x in self.tokenizer.all_special_ids if x not in keep_ids
        ]
        lm_logits[:, :, special_token_ids] = -1e9

    lm_logits = lm_logits / self.softmax_temperature

    # Row index of each query's positive passage in the local batch.
    pos_doc = (
        torch.arange(
            local_q_reps.size(0), dtype=torch.long, device=p_attention.device
        )
        * psg_per_query
    )
    pos_p_attention = p_attention[pos_doc, :, :]
    lm_targets = pos_p_attention.argmax(dim=-1)

    if "q_to_docid_loss" in arg_values and model_args.q_to_docid_loss > 0:
        lm_loss = self.cross_entropy(
            lm_logits.view(-1, lm_logits.size(-1)), lm_targets.view(-1)
        )
        loss = loss + lm_loss * model_args.q_to_docid_loss

    if "cosine_point_loss" in arg_values and model_args.cosine_point_loss > 0:
        pos_doc_lm_logits = p_reps_dt[pos_doc] @ self.lm_p.shared.weight.T
        # Compare the positive passage's top logit per position with the
        # query's logit for that same token.
        pos_doc_id_weight, pos_doc_id = pos_doc_lm_logits.max(dim=-1)
        query_id_weight = lm_logits.gather(
            dim=-1, index=pos_doc_id.unsqueeze(-1)
        ).squeeze(-1)
        cosine_sim = F.cosine_similarity(
            pos_doc_id_weight, query_id_weight, dim=-1
        )
        cosine_loss = torch.mean(1 - cosine_sim)
        loss = loss + cosine_loss * model_args.cosine_point_loss

    # A query counts as a hit only if every position matches the target.
    query_preds = lm_logits.argmax(dim=-1)
    gen_r1 = (query_preds == lm_targets).float().mean(dim=1) == 1
    self.gen_r1_list += gen_r1.cpu().tolist()

    return EncoderDecoderOutputForSeq2SeqLM(
        loss=loss, scores=scores, q_reps=q_reps, p_reps=p_reps_dt
    )
|
|
def encode_passage(self, psg):
    """Encode a passage batch; concrete subclasses must override this.

    Raises:
        NotImplementedError: always, in this base class.
    """
    message = "EncoderDecoderModel is an abstract class"
    raise NotImplementedError(message)
|
|
def encode_query(self, qry):
    """Encode a query batch; concrete subclasses must override this.

    Raises:
        NotImplementedError: always, in this base class.
    """
    message = "EncoderDecoderModel is an abstract class"
    raise NotImplementedError(message)
|
|
def compute_similarity(self, q_reps, p_reps):
    """Score query representations against passage representations.

    Concrete subclasses must override this.

    Raises:
        NotImplementedError: always, in this base class.
    """
    # The message previously said "EncoderModel", a different class; it now
    # matches the wording used by encode_query / encode_passage.
    raise NotImplementedError("EncoderDecoderModel is an abstract class")
|
|
def compute_loss(self, scores, target):
    """Apply the instance's cross-entropy criterion to ``scores`` and ``target``."""
    criterion = self.cross_entropy
    return criterion(scores, target)
|
|
def build_tree(self, log_step: int = None, log_file: str = None):
    """Load the docid table, rebuild the lookup maps and report statistics.

    Reads the tab-separated file ``model_args.docid_file_name`` (columns
    oldid, docid, docid_logit, text; docid tokens and logits are joined
    with ``"<->"``) and stores on the instance:

    * ``docid2num_docs``   - docid -> number of documents sharing it
    * ``docid2oldids``     - docid -> list of original ids
    * ``oldid2docid_logit``- original id -> tensor of per-token logits
    * ``oldid2docid``      - original id -> docid
    * ``root``             - only when ``model_args.tree == 1``: the result
      of ``TreeBuilder.build()`` over all tokenized docids

    Args:
        log_step: Step number written into the report.
        log_file: If given, the report is appended to this file; otherwise
            it is printed.
    """
    docid_df = pd.read_csv(
        self.model_args.docid_file_name,
        sep="\t",
        names=["oldid", "docid", "docid_logit", "text"],
        dtype={"oldid": str, "docid": str, "docid_logit": str, "text": str},
    ).loc[:, ["oldid", "docid", "docid_logit"]]

    time_str = time.strftime("%Y%m%d-%H%M%S")
    lines = [f"TIME={time_str}", f"STEP={log_step}"]

    # Collision statistics.
    num_docs = len(docid_df)
    docid_counts = docid_df.docid.value_counts()
    num_uniques = len(set(docid_df.docid))
    # Guard the ratio so a table without rows cannot divide by zero.
    unique_ratio = num_uniques / num_docs if num_docs else 0.0
    lines.append(
        f"num_uniques: {num_uniques}/{num_docs} ({unique_ratio*100:.2f}%)"
    )
    lines.append("[Frequent Collision]")
    lines.append(docid_counts.head(5).to_string())

    # Histogram of "number of distinct tokens in a docid". Slot i counts
    # docids with i + 1 distinct tokens. It starts with ten slots (the
    # historical report width) and grows on demand so longer docids do not
    # raise IndexError.
    num_unique_tokens = [0] * 10
    for docid in docid_df.docid.unique():
        num_unique_token = len(set(docid.split("<->")))
        missing_slots = num_unique_token - len(num_unique_tokens)
        if missing_slots > 0:
            num_unique_tokens.extend([0] * missing_slots)
        num_unique_tokens[num_unique_token - 1] += 1
    lines.append(f"distribution of number of unique tokens: {num_unique_tokens}")

    # Lookup tables.
    docid2num_docs = dict(docid_counts)
    docid2oldids = defaultdict(list)
    oldid2docid_logit = {}
    for docid, oldid, docid_logit in zip(
        docid_df.docid, docid_df.oldid, docid_df.docid_logit
    ):
        docid2oldids[docid].append(oldid)
        oldid2docid_logit[oldid] = torch.tensor(
            [float(x) for x in docid_logit.split("<->")]
        )
    oldid2docid = dict(zip(docid_df.oldid, docid_df.docid))

    # Optional tree over the token ids of every docid.
    if self.model_args.tree == 1:
        builder = TreeBuilder()
        expected_len = self.model_args.max_output_length - 1
        for docid in oldid2docid.values():
            toks = self.tokenizer.convert_tokens_to_ids(docid.split("<->"))
            if len(toks) != expected_len:
                # Report the length that is actually enforced.
                print(f"length of docid of {docid} is not {expected_len}")
                toks = toks[:expected_len]
            builder.add(toks)
        self.root = builder.build()

    self.docid2num_docs = docid2num_docs
    self.docid2oldids = docid2oldids
    self.oldid2docid_logit = oldid2docid_logit
    self.oldid2docid = oldid2docid

    report = "\n".join(lines)
    if log_file is not None:
        # Docid tokens may be non-ASCII; do not depend on the locale default.
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(report + "\n")
    else:
        print(report)
|
|
| def _dist_gather_tensor(self, t: Optional[torch.Tensor]): |
| if t is None: |
| return None |
| t = t.contiguous() |
|
|
| all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] |
| dist.all_gather(all_tensors, t) |
|
|
| all_tensors[self.process_rank] = t |
| all_tensors = torch.cat(all_tensors, dim=0) |
|
|
| return all_tensors |
|
|
@classmethod
def build(
    cls,
    model_args: ModelArguments,
    train_args: TrainingArguments,
    tokenizer=None,
    **hf_kwargs,
):
    """Create a model for training from ``model_args.model_name_or_path``.

    ``hf_kwargs`` must contain ``config``; all of ``hf_kwargs`` is forwarded
    to ``TRANSFORMER_CLS.from_pretrained``. The config, softmax temperature
    and ``model_args`` are stored as class attributes.

    Weight loading:
      * local directory + untied encoders: load ``query_model`` and
        ``passage_model`` sub-directories, or the directory itself for both
        when ``query_model`` does not exist;
      * otherwise: load one model; the passage side is a deep copy when
        encoders are untied, else the same object.
    """
    # Shared settings live on the class, not on the instance.
    cls.config = hf_kwargs["config"]
    cls.softmax_temperature = model_args.softmax_temperature
    cls.model_args = model_args

    source = model_args.model_name_or_path
    untied = model_args.untie_encoder
    load_lm = cls.TRANSFORMER_CLS.from_pretrained

    if os.path.isdir(source) and untied:
        qry_path = os.path.join(source, "query_model")
        psg_path = os.path.join(source, "passage_model")
        if not os.path.exists(qry_path):
            # No split checkpoint on disk: start both sides from the root.
            qry_path = psg_path = source
        logger.info(f"loading query model weight from {qry_path}")
        lm_q = load_lm(qry_path, **hf_kwargs)
        logger.info(f"loading passage model weight from {psg_path}")
        lm_p = load_lm(psg_path, **hf_kwargs)
    else:
        lm_q = load_lm(source, **hf_kwargs)
        # A local directory only reaches this branch with tied encoders, so
        # the deep copy applies solely to non-directory sources.
        lm_p = copy.deepcopy(lm_q) if untied else lm_q

    return cls(
        lm_q=lm_q,
        lm_p=lm_p,
        negatives_x_device=train_args.negatives_x_device,
        untie_encoder=untied,
        tokenizer=tokenizer,
    )
|
|
@classmethod
def load(
    cls,
    model_args: ModelArguments,
    tokenizer: Optional[AutoTokenizer] = None,
    **hf_kwargs,
):
    """Restore a model from ``model_args.model_name_or_path``.

    ``hf_kwargs`` must contain ``config``; all of ``hf_kwargs`` is forwarded
    to ``TRANSFORMER_CLS.from_pretrained``. The config, softmax temperature
    and ``model_args`` are stored as class attributes.

    If the path is a local directory containing a ``query_model``
    sub-directory, query and passage models are loaded separately from
    ``query_model`` / ``passage_model``. Otherwise a single model is loaded
    and shared by both sides.

    Args:
        model_args: Model arguments (path, softmax temperature, ...).
        tokenizer: Optional tokenizer handed to the constructed model.

    Returns:
        The constructed model. Its ``untie_encoder`` flag is true exactly
        when separate weights were loaded.
    """
    cls.config = hf_kwargs["config"]
    cls.softmax_temperature = model_args.softmax_temperature
    cls.model_args = model_args

    model_path = model_args.model_name_or_path
    qry_model_path = os.path.join(model_path, "query_model")
    psg_model_path = os.path.join(model_path, "passage_model")
    load_lm = cls.TRANSFORMER_CLS.from_pretrained

    # The flag must mirror what was actually loaded: save() uses it to
    # decide whether the passage model gets written at all. It used to be
    # inverted, so a split checkpoint was re-saved without its passage model.
    untie_encoder = os.path.isdir(model_path) and os.path.exists(qry_model_path)

    if untie_encoder:
        logger.info("found separate weight for query/passage encoders")
        logger.info(f"loading query model weight from {qry_model_path}")
        lm_q = load_lm(qry_model_path, **hf_kwargs)
        logger.info(f"loading passage model weight from {psg_model_path}")
        lm_p = load_lm(psg_model_path, **hf_kwargs)
    else:
        logger.info("try loading tied weight")
        logger.info(f"loading model weight from {model_path}")
        lm_q = load_lm(model_path, **hf_kwargs)
        lm_p = lm_q

    # Pass the tokenizer on; it used to be accepted and silently dropped,
    # leaving self.tokenizer as None.
    return cls(
        lm_q=lm_q,
        lm_p=lm_p,
        untie_encoder=untie_encoder,
        tokenizer=tokenizer,
    )
|
|
def save(self, output_dir: str):
    """Write the model weights below ``output_dir``.

    With untied encoders the query model goes to ``output_dir/query_model``
    and the passage model to ``output_dir/passage_model``; otherwise only
    the shared model is saved directly into ``output_dir``.

    Args:
        output_dir: Target directory.
    """
    if not self.untie_encoder:
        self.lm_q.save_pretrained(output_dir)
        return

    for sub_dir, lm in (("query_model", self.lm_q), ("passage_model", self.lm_p)):
        target = os.path.join(output_dir, sub_dir)
        # exist_ok: saving twice into the same directory (e.g. overwriting
        # a checkpoint) must not raise FileExistsError.
        os.makedirs(target, exist_ok=True)
        lm.save_pretrained(target)
|
|