Spaces:
Running on Zero
Running on Zero
File size: 2,351 Bytes
0afe769 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | import torch
import torch.nn as nn
import torch.nn.functional as F
def interp_table(t: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    """Linearly interpolate ``table`` along its first axis at positions ``t``.

    ``t`` is clamped to ``[0, 1]`` and mapped onto the index range
    ``[0, T - 1]`` with ``T = table.shape[0]``. The result is a blend of the
    two table entries that bracket each mapped position.

    Args:
        t: Query positions. Values outside ``[0, 1]`` are clamped.
        table: Lookup table interpolated along axis 0. Any trailing axes are
            carried through unchanged.

    Returns:
        Tensor of shape ``t.shape + table.shape[1:]``.
    """
    last = table.shape[0] - 1
    pos = t.clamp(0.0, 1.0) * last
    lower = pos.floor().long()
    # Clamp so that t == 1 (lower == last) does not index past the table end.
    upper = (lower + 1).clamp(max=last)
    frac = pos - lower.float()
    # Append singleton axes so the weights broadcast over any trailing table
    # axes; for a 1-D table this reshape is a no-op.
    frac = frac.reshape(tuple(frac.shape) + (1,) * (table.dim() - 1))
    y_lower = table[lower]
    y_upper = table[upper]
    return y_lower * (1 - frac) + y_upper * frac
def random_mask(token_lengths, min_ratio=0., max_ratio=0.3):
    """Pick a random, length-proportional index for every sequence.

    For each entry of ``token_lengths`` a ratio is drawn uniformly from
    ``[min_ratio, max_ratio)`` and the rounded product ``length * ratio`` is
    returned. The original code names the result ``mask_start_ids``, so it is
    presumably the position where masking starts (confirm against callers).

    Args:
        token_lengths: Tensor of per-sequence lengths; its first axis is the
            batch axis.
        min_ratio: Lower bound of the sampled ratio.
        max_ratio: Upper bound of the sampled ratio.

    Returns:
        ``int32`` tensor on the same device as ``token_lengths``.
    """
    batch_size = token_lengths.shape[0]
    ratio_span = max_ratio - min_ratio
    # Drawn on the default (CPU) generator first, then moved to the device.
    ratios = torch.rand(batch_size).to(token_lengths.device)
    ratios = ratios * ratio_span + min_ratio
    scaled = token_lengths * ratios
    return scaled.round().int()
def sequence_mask(length: torch.Tensor, max_length: int = None, left_padded: bool = False) -> torch.Tensor:
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        length: Per-sequence lengths; expected to be 1-D, shape ``(batch,)``.
        max_length: Width of the mask. When ``None`` the largest value in
            ``length`` is used (``0`` for an empty batch).
        left_padded: If ``True`` the valid positions are the last
            ``length[i]`` columns of each row; otherwise they are the first
            ``length[i]`` columns.

    Returns:
        Boolean tensor of shape ``(batch, max_length)`` in which ``True``
        marks a valid (non-padding) position.
    """
    if max_length is None:
        # Tensor.max() raises on an empty tensor, so an empty batch gets a
        # zero-width mask instead of crashing.
        max_length = length.max() if length.numel() > 0 else 0
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0)
    lengths = length.unsqueeze(1)
    if left_padded:
        # Padding occupies the first (max_length - length) columns.
        return positions >= (max_length - lengths)
    return positions < lengths
def pad_nested_tensor(nested_tensor, padding_value=0, left_padded=False):
    """Convert a nested tensor into a dense, padded tensor.

    Args:
        nested_tensor: A ``torch.nested`` tensor of variable-length sequences.
        padding_value: Fill value used for the padded positions.
        left_padded: If ``True`` the padding is placed before each sequence
            (along dim 1) instead of after it.

    Returns:
        The dense padded tensor.
    """
    if not left_padded:
        return torch.nested.to_padded_tensor(nested_tensor, padding=padding_value)

    # Reverse every sequence, pad the reversed sequences, then reverse the
    # padded axis again: the padding ends up in front of the data.
    flipped_seqs = [torch.flip(seq, dims=[0]) for seq in nested_tensor.unbind()]
    flipped_nested = torch.nested.nested_tensor(flipped_seqs)
    dense = torch.nested.to_padded_tensor(flipped_nested, padding=padding_value)
    return torch.flip(dense, dims=[1])
def logits_top_p(logits, top_p):
    """Apply nucleus (top-p) filtering to ``logits`` along the last axis.

    Entries are ranked by descending logit. The smallest prefix whose
    cumulative softmax probability exceeds ``top_p`` is kept, and the
    highest-ranked entry is always kept. All other entries are set to
    ``-inf``.

    Args:
        logits: Tensor of unnormalised scores; the last axis is filtered.
        top_p: Cumulative probability threshold.

    Returns:
        A new tensor with the filtered-out entries replaced by ``-inf``.
    """
    ranked_logits, ranking = torch.sort(logits, descending=True)
    cum_probs = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
    over_budget = cum_probs > top_p
    # Shift the mask one step right so the entry that first crosses the
    # threshold is still kept. The wrapped-around first slot is cleared so
    # the top-ranked entry always survives.
    drop_ranked = torch.roll(over_budget, shifts=1, dims=-1)
    drop_ranked[..., 0] = False
    # Map the mask from rank order back to the original positions.
    drop = drop_ranked.scatter(dim=-1, index=ranking, src=drop_ranked)
    return logits.masked_fill(drop, float('-inf'))
def gumbel_noise(t):
    """Draw Gumbel noise shaped like ``t``.

    Uniform samples ``u`` in ``[0, 1)`` are transformed as
    ``-log(-log(u + eps) + eps)`` with ``eps = 1e-6`` guarding both
    logarithms against a zero argument.

    Args:
        t: Tensor whose shape, dtype and device the noise copies.

    Returns:
        Noise tensor with the same shape, dtype and device as ``t``.
    """
    eps = 1e-6
    uniform = torch.empty_like(t).uniform_(0, 1)
    neg_log_uniform = -torch.log(uniform + eps)
    return -torch.log(neg_log_uniform + eps)
def gumbel_sample(t, dim=-1):
    """Sample indices from ``t`` using the Gumbel-max trick.

    Gumbel noise from ``gumbel_noise`` is added to ``t`` and the index of
    the largest perturbed value along ``dim`` is returned.

    Args:
        t: Tensor of scores (e.g. logits).
        dim: Axis along which the argmax is taken.

    Returns:
        Tensor of selected indices, with ``dim`` reduced away.
    """
    perturbed = t + gumbel_noise(t)
    return perturbed.argmax(dim=dim)