# GLEN-model/src/tevatron/modeling/glen_phase2.py
import torch
import logging
import numpy as np
from torch.nn import functional as F
from typing import Dict, Optional
from tevatron.tree import dec_2d
from tevatron.modeling import EncoderDecoderModelForSeq2SeqLM
logger = logging.getLogger(__name__)
class GLENP2Model(EncoderDecoderModelForSeq2SeqLM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # Automatic device detection: prefer CUDA when available. Note that
        # get_device() below reports the actual placement of lm_p's parameters.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_device(self):
"""Get the device of the model parameters"""
return next(self.lm_p.parameters()).device
def encode_passage(self, psg: Optional[Dict]):
"""Encode passage"""
if psg is None:
return None, None, None
past_key_values, encoder_outputs = None, None
decoder_inputs_embeds = self.lm_p.get_input_embeddings()(
psg["decoder_input_ids"]
)
decoder_attention_mask_full = torch.ones(
psg["input_ids"].shape[0],
self.model_args.num_multi_vectors,
dtype=torch.long,
device=decoder_inputs_embeds.device,
)
p_reps = []
for i in range(self.model_args.num_multi_vectors):
decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
psg_out = self.lm_p(
input_ids=psg["input_ids"],
attention_mask=psg["attention_mask"],
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
encoder_outputs=encoder_outputs,
output_hidden_states=True,
use_cache=True,
past_key_values=past_key_values,
)
            if encoder_outputs is None:
                # Cache encoder states from the first step so subsequent steps
                # only run the decoder incrementally (with past_key_values).
                encoder_outputs = psg_out.encoder_hidden_states
            past_key_values = psg_out.past_key_values
decoder_inputs_embeds = psg_out.decoder_hidden_states[-1][:, -1:, :]
p_reps.append(psg_out.decoder_hidden_states[-1][:, -1:, :])
p_reps = torch.cat(p_reps, dim=1) # (B * train_n_passages, M, H)
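        # Scale by 1/sqrt(H), mirroring dot-product attention scaling; the same
        # scaling is applied to query vectors in encode_query.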
p_reps = p_reps * (p_reps.size(-1) ** -0.5)
p_reps_dt = p_reps.clone()
lm_logits = p_reps @ self.lm_p.shared.weight.T
if self.model_args.mask_special_tokens_for_decoding:
special_token_ids = self.tokenizer.all_special_ids
special_token_ids = [
x
for x in special_token_ids
if x
not in [
self.tokenizer.bos_token_id,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
]
]
lm_logits[:, :, special_token_ids] = -float("inf")
if self.model_args.do_docid_temperature_annealing:
first_temperature = self.model_args.docid_temperature
cur_epoch = (
self.trainer.state.epoch if hasattr(self, "trainer") else self.cur_epoch
)
temperature = max(
self.model_args.docid_temperature_min,
first_temperature * np.exp(-cur_epoch),
)
else:
temperature = self.model_args.docid_temperature
lm_logits = lm_logits / temperature
lm_attention = torch.softmax(lm_logits, dim=-1)
p_reps = lm_attention @ self.lm_p.shared.weight
return p_reps, lm_attention, p_reps_dt
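    # The sketch below is illustrative only: a self-contained, hedged rendition
    # of the vocabulary projection at the end of encode_passage. All sizes are
    # made up, and `emb` stands in for lm_p.shared.weight.
    @staticmethod
    def _lexical_projection_sketch():
        """Softmax decoder states over the token embedding table and re-embed
        them, so each vector becomes a convex combination of token embeddings."""
        reps = torch.randn(2, 3, 4)  # (B, M, H) scaled decoder states
        emb = torch.randn(10, 4)  # (V, H) shared token embedding table
        attn = torch.softmax(reps @ emb.T / 1.0, dim=-1)  # (B, M, V), temperature 1.0
        return attn @ emb  # (B, M, H) "lexical" representations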
def encode_query(self, qry: Optional[Dict]):
"""Encode query"""
if qry is None:
return None
past_key_values, encoder_outputs = None, None
decoder_inputs_embeds = self.lm_q.get_input_embeddings()(
qry["decoder_input_ids"]
)
decoder_attention_mask_full = torch.ones(
qry["input_ids"].shape[0],
self.model_args.num_multi_vectors,
dtype=torch.long,
device=decoder_inputs_embeds.device,
)
q_reps = []
for i in range(self.model_args.num_multi_vectors):
decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
qry_out = self.lm_q(
input_ids=qry["input_ids"],
attention_mask=qry["attention_mask"],
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
encoder_outputs=encoder_outputs,
output_hidden_states=True,
use_cache=True,
past_key_values=past_key_values,
)
if encoder_outputs is None:
encoder_outputs = qry_out.encoder_hidden_states
past_key_values = qry_out.past_key_values
decoder_inputs_embeds = qry_out.decoder_hidden_states[-1][:, -1:, :]
q_reps.append(qry_out.decoder_hidden_states[-1][:, -1:, :])
        q_reps = torch.cat(q_reps, dim=1)  # (B, M, H)
q_reps = q_reps * (q_reps.size(-1) ** -0.5)
return q_reps
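    # Note: unlike passages, query representations are the scaled decoder states
    # used directly; they are not projected back through the vocabulary.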
def compute_similarity(self, q_reps: torch.Tensor, p_reps: torch.Tensor):
"""Compute similarity between query and passage"""
# q_reps: (B, M, H), p_reps: (B * train_n_passages, M, H)
q_reps = q_reps.permute(1, 0, 2) # (M, B, H)
p_reps = p_reps.permute(1, 0, 2) # (M, B * train_n_passages, H)
scores = torch.bmm(q_reps, p_reps.permute(0, 2, 1)).permute(
1, 0, 2
) # (M, B, B*train_n_passages) -> (B, M, B*train_n_passages)
scores = scores.sum(dim=1) / self.model_args.num_multi_vectors
return scores
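    # The sketch below is illustrative only: it checks on random tensors that the
    # bmm-based scoring in compute_similarity equals a single einsum (all sizes
    # are made up).
    @staticmethod
    def _compute_similarity_sketch():
        """Hedged equivalence check: mean over M per-vector dot products."""
        B, N, M, H = 2, 4, 3, 8  # queries, passages, vectors, hidden size
        q, p = torch.randn(B, M, H), torch.randn(N, M, H)
        bmm_scores = (
            torch.bmm(q.permute(1, 0, 2), p.permute(1, 2, 0))
            .permute(1, 0, 2)
            .sum(1)
            / M
        )
        ein_scores = torch.einsum("bmh,nmh->bn", q, p) / M
        assert torch.allclose(bmm_scores, ein_scores, atol=1e-5)
        return ein_scores  # (B, N) score matrix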
def make_doc_id(self, batch: Dict):
"""Process document and make doc_id"""
self.eval()
with torch.no_grad():
past_key_values, encoder_outputs = None, None
device = self.get_device()
decoder_inputs_embeds = self.lm_p.get_input_embeddings()(
torch.tensor([0], dtype=torch.long, device=device)
) # [1, H]
decoder_inputs_embeds = decoder_inputs_embeds.unsqueeze(0).repeat(
batch["source_ids"].shape[0], 1, 1
) # [B, 1, H]
decoder_attention_mask_full = torch.ones(
batch["source_ids"].shape[0],
self.model_args.max_output_length - 1,
dtype=torch.long,
device=device,
)
outs, out_logits = [], []
for i in range(self.model_args.max_output_length - 1):
decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
psg_out = self.lm_p(
input_ids=batch["source_ids"].to(device),
attention_mask=batch["source_mask"].to(device),
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
encoder_outputs=encoder_outputs,
output_hidden_states=True,
use_cache=True,
past_key_values=past_key_values,
)
if encoder_outputs is None:
encoder_outputs = psg_out.encoder_hidden_states
past_key_values = psg_out.past_key_values
decoder_inputs_embeds = psg_out.decoder_hidden_states[-1][:, -1:, :]
if self.model_args.mask_special_tokens_for_decoding:
psg_out.logits[:, :, self.model_args.special_token_ids] = -float(
"inf"
)
                out = psg_out.logits[:, -1, :].argmax(dim=-1)  # [B]
                out_logit = (
                    psg_out.logits[:, -1, :].gather(1, out.unsqueeze(-1)).squeeze(-1)
                )  # [B]
                outs.append(out.cpu().numpy())
                out_logits.append(out_logit.cpu().detach().numpy())
            outs = np.stack(outs, axis=1)  # [B, L-1]
            out_logits = np.stack(out_logits, axis=1)  # [B, L-1]
preds = []
for ids in outs:
preds.append(
"<->".join(
self.tokenizer.convert_ids_to_tokens(
ids, skip_special_tokens=False
)
)
)
texts = [
self.tokenizer.decode(ids.numpy(), skip_special_tokens=True)
for ids in batch["source_ids"]
]
            oldids = list(batch["oldids"])
out_logits = np.round(out_logits.astype(np.float64), 4)
return oldids, texts, preds, out_logits
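    # A generated doc ID is the "<->"-joined token sequence, e.g. (with made-up
    # tokens) "▁deep<->▁learning<-></s>"; out_logits keeps the per-step greedy
    # logits so that downstream code can break ties between identical IDs.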
def evaluation_step(self, batch: Dict, predefined_id: bool = False):
"""
        Evaluate the model on a batch.
        predefined_id is True for validation during training (a ground-truth
        query exists) and False for inference after training; in the latter
        case, decoder logits are used for breaking ties among documents that
        share a doc ID.
        Args:
            batch (Dict):
                batch data (source_ids, source_mask, aug_source_ids, aug_source_mask, target_ids, target_mask, rank, oldid)
            predefined_id (bool):
                whether a ground-truth query exists (True: validation during training, False: inference after training)
        Returns:
            texts (List[str]):
                list of source texts
            preds (List[str]):
                list of predicted doc IDs
            target_docids (List[str]):
                list of target doc IDs
            ranks (List[str]):
                list of ranks
"""
# TODO Merge code with GLEN P1 evaluation_step
        # If a ground-truth ID exists, an additional digit is appended to the end of the ID
max_output_length = (
self.model_args.max_output_length
if predefined_id
else self.model_args.max_output_length - 1
)
num_return_sequences = self.model_args.num_return_sequences
with torch.no_grad():
past_key_values, encoder_outputs = None, None
K = num_return_sequences * 2
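            # Track twice as many beams as will be returned, presumably to keep
            # candidates in reserve when paths fall out of the doc-ID tree; only
            # the top num_return_sequences survive the final slice.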
device = self.get_device()
decoder_inputs_embeds = self.lm_p.get_input_embeddings()(
torch.tensor([0], dtype=torch.long, device=device)
) # [1, H]
decoder_inputs_embeds = decoder_inputs_embeds.unsqueeze(0).repeat(
batch["source_ids"].shape[0], 1, 1
) # [B, 1, H]
decoder_attention_mask_full = torch.ones(
batch["source_ids"].shape[0],
self.model_args.max_output_length - 1,
dtype=torch.long,
device=device,
)
            decode_tree = self.root if self.model_args.tree else None
            # Constrained decoding walks the doc-ID prefix tree; the masking
            # below dereferences it unconditionally, so fail fast if it is missing.
            assert decode_tree is not None, "evaluation_step requires model_args.tree"
outs, out_logits, all_next_token_logits = [], [], []
for i in range(max_output_length):
decoder_attention_mask = decoder_attention_mask_full[:, : i + 1]
psg_out = self.lm_p(
input_ids=batch["source_ids"].to(device),
attention_mask=batch["source_mask"].to(device),
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
encoder_outputs=encoder_outputs,
output_hidden_states=True,
use_cache=True,
past_key_values=past_key_values,
)
if encoder_outputs is None:
encoder_outputs = psg_out.encoder_hidden_states
past_key_values = psg_out.past_key_values
decoder_inputs_embeds = psg_out.decoder_hidden_states[-1][:, -1:, :]
if (
not predefined_id
and self.model_args.mask_special_tokens_for_decoding
):
psg_out.logits[:, :, self.model_args.special_token_ids] = -float(
"inf"
)
next_token_logits = psg_out.logits[:, -1, :] # [B, V]
all_next_token_logits.append(
next_token_logits.detach().cpu().unsqueeze(1)
) # [B, 1, V]
batch_size, vocab_size = next_token_logits.shape
scores = F.log_softmax(next_token_logits, dim=-1) # [B, V]
if i == 0:
mask = torch.ones_like(scores) * -float("inf") # [B, V]
candidates = list(decode_tree.children.keys())
mask[:, candidates] = 0 # [B, V]
scores += mask # [B, V]
                    out_prob, out_index = scores.topk(
                        K, dim=-1
                    )  # NOTE: top-k is doubled (K = 2 * num_return_sequences)
cur_prob, cur_index = out_prob, out_index
all_index = cur_index.unsqueeze(-1) # [B, K, 1]
else:
out_prob = scores
all_prob = cur_prob[:, :, None] + out_prob.unsqueeze(1) # [B, K, V]
mask = torch.ones_like(all_prob) * -float("inf")
for b in range(batch_size):
for k in range(K):
previous_index = all_index[b, k, :].tolist()
cur = decode_tree
for value in previous_index:
if value not in cur.children:
                                    next_candidates = [self.tokenizer.eos_token_id]  # path left the tree: allow only EOS
break
else:
cur = cur.children[value]
else:
next_candidates = list(cur.children.keys())
mask[b, k, next_candidates] = 0
all_prob += mask
                    top_prob, top_index = all_prob.view(batch_size, -1).topk(
                        K, dim=-1
                    )  # [B, K]
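                    # top_index enumerates the flattened (beam, token) grid:
                    # beam = index // vocab_size, token = index % vocab_size.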
before_index = torch.div(
top_index, vocab_size, rounding_mode="floor"
)
after_index = top_index % vocab_size
before_index = all_index.gather(
1, before_index.unsqueeze(-1).repeat(1, 1, all_index.shape[-1])
)
all_index = torch.cat(
[before_index, after_index.unsqueeze(-1)], dim=-1
)
cur_prob = top_prob
            all_index = all_index[:, :num_return_sequences, :]  # [B, num_return_sequences, L]
outs = all_index.reshape(
-1, max_output_length
) # [B*num_return_sequences, L]
outs = outs.detach().cpu()
if predefined_id:
dec = []
for out in outs:
if self.tokenizer.eos_token_id in out:
out = out[: list(out).index(self.tokenizer.eos_token_id)]
dec.append(
"-".join(
self.tokenizer.convert_ids_to_tokens(
out, skip_special_tokens=True
)
)
)
else:
# Calculate logits for breaking ties
out_logits = torch.zeros_like(
outs, dtype=torch.float32
) # [B*num_return_sequences, L]
all_next_token_logits = (
torch.cat(all_next_token_logits, dim=1).detach().cpu()
) # [B, L, V]
for b in range(batch_size):
outs_b = outs[
b * num_return_sequences : (b + 1) * num_return_sequences
] # [num_return_sequences, L]
all_next_token_logits_b = all_next_token_logits[b : b + 1].repeat(
num_return_sequences, 1, 1
) # [num_return_sequences, L, V]
out_logits_b = all_next_token_logits_b.gather(
2, outs_b.unsqueeze(-1)
).squeeze(
-1
) # [num_return_sequences, L]
out_logits[
b * num_return_sequences : (b + 1) * num_return_sequences
] = out_logits_b
dec, temp_dec = [], []
for out_cnt, ids in enumerate(outs):
pred_id = "<->".join(
self.tokenizer.convert_ids_to_tokens(
ids, skip_special_tokens=False
)
)
out_logits_i = out_logits[out_cnt]
if self.docid2num_docs[pred_id] == 1:
oldid = self.docid2oldids[pred_id][0]
temp_dec.append(oldid + "<->" + pred_id)
else:
if self.model_args.reranking == "random":
cand_oldids = self.docid2oldids[pred_id] # list of oldids
cand_oldids = np.random.permutation(cand_oldids)
pred_id_list = [
oldid + "<->" + pred_id for oldid in cand_oldids
]
temp_dec += pred_id_list
elif self.model_args.reranking in ["cosine", "mse"]:
cand_oldids = self.docid2oldids[pred_id] # list of oldids
rank_scores = []
for cand_oldid in cand_oldids:
docid_logit = self.oldid2docid_logit[cand_oldid] # [L]
if self.model_args.reranking == "cosine":
rank_score = torch.cosine_similarity(
docid_logit.unsqueeze(0),
out_logits_i.unsqueeze(0),
).item()
elif self.model_args.reranking == "mse":
rank_score = (
(docid_logit - out_logits_i)
.pow(2)
.mean()
.item()
)
rank_scores.append(rank_score)
rank_scores = np.array(rank_scores)
order_index = np.argsort(-rank_scores)
ordered_cand_oldids = np.array(cand_oldids)[order_index]
pred_id_list = [
oldid + "<->" + pred_id for oldid in ordered_cand_oldids
]
temp_dec += pred_id_list
else:
raise NotImplementedError(
"reranking method should be one of ['random', 'cosine', 'mse']"
)
if out_cnt % num_return_sequences == num_return_sequences - 1:
temp_dec = temp_dec[:num_return_sequences]
dec += temp_dec
temp_dec = []
targets = []
for oldid in batch["oldid"]:
targets.append(oldid + "<->" + self.oldid2docid[oldid])
batch["rank"][0][0] = targets
texts = [
self.tokenizer.decode(ids, skip_special_tokens=True)
for ids in batch["source_ids"]
]
preds = dec_2d(dec, num_return_sequences)
target_docids = batch["rank"][0][0]
ranks = [str(a.item()) for a in batch["rank"][0][1]]
return texts, preds, target_docids, ranks
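

if __name__ == "__main__":
    # Hedged, self-contained sketch of the prefix-tree constrained decoding used
    # in evaluation_step. `Node`, the toy vocabulary size, and the example doc
    # IDs below are made up for illustration and are not part of the pipeline.
    class Node:
        def __init__(self):
            self.children: Dict[int, "Node"] = {}

    def add_doc_id(root: Node, token_ids):
        cur = root
        for t in token_ids:
            cur = cur.children.setdefault(t, Node())

    root = Node()
    add_doc_id(root, [5, 7, 1])  # doc ID "5 7 </s>" (1 plays the role of EOS)
    add_doc_id(root, [5, 9, 1])  # doc ID "5 9 </s>"

    vocab_size = 12
    scores = torch.randn(vocab_size)
    # Step 0: only children of the root are reachable, exactly as in the
    # `i == 0` branch of evaluation_step above.
    mask = torch.full((vocab_size,), -float("inf"))
    mask[list(root.children.keys())] = 0
    first_token = int((scores + mask).argmax())
    print("step 0 candidates:", sorted(root.children), "->", first_token)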