# Efficient-DLM-4B / chat_utils.py
import numpy as np
import torch
import torch.nn.functional as F

def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel-max trick for sampling from categorical distributions.
    According to arXiv:2409.02908, for MDMs, low-precision Gumbel-max improves
    the perplexity score but reduces generation quality, so float64 is used here.
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
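
# A minimal sanity-check sketch (not part of the original file): with
# temperature t > 0, argmax over `add_gumbel_noise(logits, t)` is equivalent to
# sampling from softmax(logits / t), since log(exp(logits) / (-log u)**t)
# = logits + t * g, where g = -log(-log u) is a standard Gumbel variable:
#
#   logits = torch.randn(2, 8, 128)  # toy (batch, length, vocab) logits
#   sampled = torch.argmax(add_gumbel_noise(logits, 0.7), dim=-1)  # stochastic
#   greedy = torch.argmax(add_gumbel_noise(logits, 0.0), dim=-1)   # == argmax(logits)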

def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)

    if remasking == 'low_confidence':
        # Confidence = probability of the predicted token.
        p = F.softmax(logits, dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # (B, L)
    elif remasking == 'top_p_margin':
        # Confidence = margin between the top-2 probabilities per position.
        p = F.softmax(logits, dim=-1)  # (B, L, V)
        top2 = torch.topk(p, k=2, dim=-1).values  # (B, L, 2)
        margin = top2[..., 0] - top2[..., 1]  # (B, L)
        # Normalize the margin to [0, 1] over MASKED positions per row.
        plus_inf = torch.full_like(margin, float('inf'))
        minus_inf = torch.full_like(margin, float('-inf'))
        masked_for_min = torch.where(mask_index, margin, plus_inf)
        masked_for_max = torch.where(mask_index, margin, minus_inf)
        row_min = masked_for_min.amin(dim=1, keepdim=True)  # (B, 1)
        row_max = masked_for_max.amax(dim=1, keepdim=True)  # (B, 1)
        denom = row_max - row_min
        # If denom == 0 (all margins equal), set normalized = 1 on masked
        # positions; elsewhere it stays 0 by default.
        normalized = torch.zeros_like(margin)
        nonzero = denom > 0
        normalized = torch.where(
            mask_index & nonzero,
            (margin - row_min) / (denom + 1e-12),
            normalized
        )
        normalized = torch.where(
            mask_index & (~nonzero),
            torch.ones_like(normalized),
            normalized
        )
        x0_p = normalized  # in [0, 1] on masked positions
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    # Use the negative entropy of the full distribution as confidence if requested.
    if neg_entropy:
        p = F.softmax(logits, dim=-1)
        epsilon = 1e-10
        log_probs = torch.log(p + epsilon)
        confidence_scores = torch.sum(p * log_probs, dim=-1)  # negative entropy per position
    else:
        confidence_scores = x0_p

    # Keep already-decoded tokens; only masked positions are candidates.
    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, confidence_scores, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    if threshold is not None:
        # With a threshold, consider every masked position as a candidate.
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
    for j in range(confidence.shape[0]):
        _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
        transfer_index[j, select_index] = True
        if threshold is not None:
            # Always commit the single most confident token (k = 0), then drop
            # any further selections that fall below the threshold.
            for k in range(1, num_transfer_tokens[j]):
                if confidence[j, select_index[k]] < threshold:
                    transfer_index[j, select_index[k]] = False
    return x0, transfer_index
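
# Usage sketch (hypothetical tensors, not from the original file): commit the
# most confident predictions for one denoising step.
#
#   mask_index = (x == mask_id)                              # (B, L) bool
#   schedule = get_num_transfer_tokens(mask_index, steps=4)  # (B, 4)
#   x0, idx = get_transfer_index(
#       logits, 0.0, 'low_confidence', mask_index, x, schedule[:, 0])
#   x = torch.where(idx, x0, x)                              # fill selected positions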

def get_num_transfer_tokens(mask_index, steps: int):
    # Split the number of masked tokens per row as evenly as possible across
    # `steps`, giving the first `remainder` steps one extra token each.
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, : int(remainder[i])] += 1
    return num_transfer_tokens
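
# Worked example (a sketch): with 10 masked positions and steps=4,
# base = 2 and remainder = 2, so the per-step schedule is [3, 3, 2, 2];
# the earlier steps absorb the remainder.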

@torch.no_grad()
def generate_with_prefix_cache_block_diff(
    model,
    prompt,
    steps=128,
    gen_length=128,
    block_length=128,
    temperature=0.,
    remasking='low_confidence',
    mask_id=126336,
    threshold=None,
    factor=None,
    shift_logits=False,
    neg_entropy=False,
):
    dream_style = shift_logits

    # Initialize the accumulator that holds the prompt plus generated tokens.
    x_accum = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    nfe = 0

    # Compute the KV cache for the prompt once up front.
    output = model(prompt, use_cache=True)
    past_key_values = output.past_key_values

    # For dream_style: store the "next token" logit of the context.
    next_logits_context = None
    if dream_style:
        next_logits_context = output.logits[:, -1:, :]  # (B, 1, V)
    for num_block in range(num_blocks):
        # Create a new block filled with mask tokens (no seeding).
        mask_block = torch.ones(
            (prompt.shape[0], block_length),
            dtype=prompt.dtype,
            device=prompt.device
        ) * mask_id
        # Append the block of masks.
        x_accum = torch.cat([x_accum, mask_block], dim=1)

        current_block_start = prompt.size(1) + num_block * block_length
        block_slice = slice(current_block_start, current_block_start + block_length)

        # Precompute the transfer schedule for this block. All positions start
        # masked (none are seeded), so dream_style still denoises every
        # position and the schedule is the same in both modes.
        schedule_mask = (x_accum[:, block_slice] == mask_id)  # (B, Lb)
        num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block)  # (B, steps_per_block)
        # Denoise the current block.
        for i in range(steps_per_block):
            mask_block_idx = (x_accum[:, block_slice] == mask_id)  # (B, Lb)
            if mask_block_idx.sum() == 0:
                break
            nfe += 1

            # Forward only the current noisy block, reusing the cached context.
            logits_block = model(
                x_accum[:, block_slice],
                past_key_values=past_key_values,
                use_cache=False
            ).logits

            if dream_style:
                # Align logits so that each masked position has a predictor:
                # prepend the context-next logit, then use logits_block[:, :-1].
                if block_length == 1:
                    logits_use = next_logits_context  # (B, 1, V)
                else:
                    logits_use = torch.cat(
                        [next_logits_context, logits_block[:, :-1, :]],
                        dim=1
                    )  # (B, Lb, V)
            else:
                # Non-shifted (same-position) case.
                logits_use = logits_block

            x0, transfer_idx = get_transfer_index(
                logits_use, temperature, remasking, mask_block_idx,
                x_accum[:, block_slice],
                num_transfer_tokens=num_transfer_tokens[:, i],
                threshold=threshold, neg_entropy=neg_entropy
            )
            cur = x_accum[:, block_slice].clone()
            cur[transfer_idx] = x0[transfer_idx]
            x_accum[:, block_slice] = cur
        # After the block is fully denoised, run it through the model once more
        # to extend the KV cache with the finished block.
        output = model(
            x_accum[:, block_slice],
            past_key_values=past_key_values,
            use_cache=True
        )
        past_key_values = output.past_key_values
        nfe += 1

        if dream_style and num_block < num_blocks - 1:
            # Refresh the context-next logit for the next block.
            next_logits_context = output.logits[:, -1:, :]  # (B, 1, V)

    return x_accum, nfe
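
# Example usage (a sketch; the repo id is an assumption, and mask_id must match
# the mask token defined by the actual checkpoint):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("Efficient-DLM-4B", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       "Efficient-DLM-4B", trust_remote_code=True).cuda().eval()
#   prompt = tok("What is 2 + 2?", return_tensors="pt").input_ids.cuda()
#   out, nfe = generate_with_prefix_cache_block_diff(
#       model, prompt, steps=128, gen_length=128, block_length=32,
#       temperature=0., remasking='low_confidence')
#   print(tok.decode(out[0, prompt.shape[1]:], skip_special_tokens=True))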