File size: 2,351 Bytes
0afe769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
import torch.nn.functional as F

def interp_table(t: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    T = table.shape[0]

    t = t.clamp(0.0, 1.0)
    pos = t * (T - 1)                  
    idx0 = pos.floor().long()          
    idx1 = (idx0 + 1).clamp(max=T - 1) 
    w = pos - idx0.float()             

    y0 = table[idx0]
    y1 = table[idx1]
    return y0 * (1 - w) + y1 * w

def random_mask(token_lengths, min_ratio=0., max_ratio=0.3):
    b = token_lengths.shape[0]
    r = torch.rand(b).to(token_lengths.device) * (max_ratio - min_ratio) + min_ratio
    mask_start_ids = (token_lengths * r).round().int()
    return mask_start_ids

def sequence_mask(length: torch.Tensor, max_length: int = None, left_padded: bool=False) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)

    if left_padded:
        return x.unsqueeze(0) >= (max_length - length).unsqueeze(1)
    else:
        return x.unsqueeze(0) < length.unsqueeze(1)

def pad_nested_tensor(nested_tensor, padding_value=0, left_padded=False):
    if left_padded:
        reversed_sequences = [seq.flip(dims=[0]) for seq in nested_tensor.unbind()]
        reversed_nested_tensor = torch.nested.nested_tensor(reversed_sequences)

        padded_tensor = torch.nested.to_padded_tensor(reversed_nested_tensor, padding=padding_value)
        return padded_tensor.flip(dims=[1])
    else:
        return torch.nested.to_padded_tensor(nested_tensor, padding=padding_value)
    
def logits_top_p(logits, top_p):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, float('-inf'))
    return logits

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -torch.log(-torch.log(noise + 1e-6) + 1e-6)

def gumbel_sample(t, dim=-1):
    return (t + gumbel_noise(t)).argmax(dim=dim)