from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def interp_table(t: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    """Linearly interpolate ``table`` along its first dimension at ``t``.

    ``t`` is clamped to ``[0, 1]`` and mapped onto the fractional index range
    ``[0, T - 1]`` with ``T = table.shape[0]``; the two neighbouring table
    entries are then blended with the fractional part as weight.

    Args:
        t: Query positions (any shape). Values outside ``[0, 1]`` are clamped.
        table: Lookup table. Dimension 0 is the interpolation axis; any
            trailing dimensions are treated as per-entry values.

    Returns:
        Tensor of shape ``t.shape + table.shape[1:]``.
    """
    num_entries = table.shape[0]
    pos = t.clamp(0.0, 1.0) * (num_entries - 1)
    lower = pos.floor().long()
    # At the right edge ``lower + 1`` would run past the table; clamp it so
    # both neighbours coincide there (the weight is 0 at that point anyway).
    upper = (lower + 1).clamp(max=num_entries - 1)
    weight = pos - lower.float()
    # ``table[lower]`` carries the table's trailing dims after ``t``'s dims.
    # Give the weight matching singleton dims so it broadcasts per query
    # instead of failing (or silently aligning against a feature axis).
    trailing_dims = table.dim() - 1
    if trailing_dims > 0:
        weight = weight.reshape(tuple(weight.shape) + (1,) * trailing_dims)
    return table[lower] * (1 - weight) + table[upper] * weight


def random_mask(token_lengths, min_ratio=0., max_ratio=0.3):
    """Draw one random index per sequence as a random fraction of its length.

    For every entry of ``token_lengths`` a ratio is sampled uniformly from
    ``[min_ratio, max_ratio)`` and multiplied by the length; the product is
    rounded to the nearest integer.

    Args:
        token_lengths: Tensor whose first dimension is the batch size.
            Presumably 1-D per-sequence token counts; verify against caller.
        min_ratio: Lower bound of the sampled ratio.
        max_ratio: Upper bound of the sampled ratio.

    Returns:
        ``int32`` tensor on ``token_lengths.device`` holding the rounded
        ``token_lengths * ratio`` values.
    """
    batch_size = token_lengths.shape[0]
    # Sampled without an explicit device and then moved, exactly as before,
    # so seeded runs keep producing the same values.
    unit_samples = torch.rand(batch_size).to(token_lengths.device)
    ratios = unit_samples * (max_ratio - min_ratio) + min_ratio
    return (token_lengths * ratios).round().int()


def sequence_mask(
    length: torch.Tensor,
    max_length: Optional[int] = None,
    left_padded: bool = False,
) -> torch.Tensor:
    """Build a boolean validity mask from per-sequence lengths.

    Args:
        length: Sequence lengths; expected to be 1-D, shape ``(batch,)``.
        max_length: Width of the mask. Defaults to ``length.max()``, or ``0``
            when ``length`` is empty.
        left_padded: If ``True`` the valid positions are the *last*
            ``length[i]`` columns of each row; otherwise the first ones.

    Returns:
        Boolean tensor of shape ``(batch, max_length)`` that is ``True`` at
        valid (non-padding) positions.
    """
    if max_length is None:
        # ``max()`` raises on an empty tensor; an empty batch gets an empty mask.
        max_length = length.max().item() if length.numel() > 0 else 0
    positions = torch.arange(
        max_length, dtype=length.dtype, device=length.device
    ).unsqueeze(0)
    lengths = length.unsqueeze(1)
    if left_padded:
        return positions >= (max_length - lengths)
    return positions < lengths


def pad_nested_tensor(nested_tensor, padding_value=0, left_padded=False):
    """Convert a nested tensor into a dense, padded tensor.

    Args:
        nested_tensor: A ``torch.nested`` tensor of variable-length sequences.
        padding_value: Fill value used for the padded positions.
        left_padded: If ``True`` the padding is placed before each sequence
            (along dimension 1) instead of after it.

    Returns:
        The dense tensor produced by ``torch.nested.to_padded_tensor``.
    """
    if not left_padded:
        return torch.nested.to_padded_tensor(nested_tensor, padding=padding_value)

    # Left padding is obtained by reversing every sequence, padding as usual,
    # and reversing the padded result back along the sequence axis.
    flipped = torch.nested.nested_tensor(
        [seq.flip(dims=[0]) for seq in nested_tensor.unbind()]
    )
    padded = torch.nested.to_padded_tensor(flipped, padding=padding_value)
    return padded.flip(dims=[1])


def logits_top_p(logits, top_p):
    """Apply nucleus (top-p) filtering along the last dimension of ``logits``.

    Entries are ranked by descending logit. Those needed to make the
    cumulative softmax probability exceed ``top_p`` are kept; the rest are
    set to ``-inf``. The highest-ranked entry is always kept.

    Args:
        logits: Unnormalised scores; filtering runs over the last dimension.
        top_p: Cumulative-probability threshold.

    Returns:
        A new tensor shaped like ``logits`` with filtered entries replaced by
        ``-inf``. The input tensor is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    exceeds = cumulative_probs > top_p
    # Shift the mask right by one: the entry that first crosses the threshold
    # is kept, and the top-ranked entry can never be removed.
    drop_sorted = torch.cat(
        [torch.zeros_like(exceeds[..., :1]), exceeds[..., :-1]], dim=-1
    )
    # Map the mask from sorted order back to the original vocabulary order.
    drop = drop_sorted.scatter(dim=-1, index=sorted_indices, src=drop_sorted)
    return logits.masked_fill(drop, float('-inf'))


def gumbel_noise(t):
    """Sample Gumbel noise with the shape, dtype and device of ``t``.

    Uses the inverse-CDF form ``-log(-log(U))`` with ``U`` drawn uniformly
    from ``[0, 1)``. The ``1e-6`` offsets keep both logarithms away from
    ``log(0)``.
    """
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -torch.log(-torch.log(noise + 1e-6) + 1e-6)


def gumbel_sample(t, dim=-1):
    """Return the argmax along ``dim`` of ``t`` perturbed by Gumbel noise.

    When ``t`` holds logits this is the Gumbel-max trick for drawing a
    categorical sample.
    """
    return (t + gumbel_noise(t)).argmax(dim=dim)