import numpy as np
import torch
import torch.nn.functional as F


def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that an argmax over the result samples a token.

    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves
    perplexity score but reduces generation quality. Thus, we use float64.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Sampling temperature. ``0`` disables the noise.

    Returns:
        ``logits`` itself (same object, same dtype) when ``temperature == 0``;
        otherwise a float64 tensor ``exp(logits) / (-log(u)) ** temperature``
        with ``u`` drawn from ``torch.rand_like``.
    """
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def _margin_confidence(logits, mask_index):
    """Return the top-1 minus top-2 probability gap, min-max scaled per row.

    Scaling uses only positions where ``mask_index`` is true. Masked positions
    of a row whose gaps are all equal get 1; unmasked positions get 0.
    """
    p = F.softmax(logits, dim=-1)  # (B, L, V)
    top2 = torch.topk(p, k=2, dim=-1).values  # (B, L, 2)
    margin = top2[..., 0] - top2[..., 1]  # (B, L)

    # +/-inf fillers keep unmasked positions out of the per-row min and max.
    plus_inf = torch.full_like(margin, float('inf'))
    minus_inf = torch.full_like(margin, float('-inf'))
    row_min = torch.where(mask_index, margin, plus_inf).amin(dim=1, keepdim=True)  # (B, 1)
    row_max = torch.where(mask_index, margin, minus_inf).amax(dim=1, keepdim=True)  # (B, 1)

    denom = row_max - row_min
    nonzero = denom > 0

    normalized = torch.zeros_like(margin)
    normalized = torch.where(
        mask_index & nonzero,
        (margin - row_min) / (denom + 1e-12),
        normalized,
    )
    # denom == 0 (all masked gaps equal): treat every masked position as 1.
    normalized = torch.where(
        mask_index & (~nonzero),
        torch.ones_like(normalized),
        normalized,
    )
    return normalized


def get_transfer_index(logits, temperature, remasking, mask_index, x,
                       num_transfer_tokens, threshold=None, neg_entropy=False):
    """Pick candidate tokens and decide which masked positions to commit.

    Args:
        logits: Scores indexed as (batch, position, vocab).
        temperature: Passed to ``add_gumbel_noise`` before the argmax.
        remasking: Confidence strategy: ``'low_confidence'`` (probability of
            the chosen token), ``'top_p_margin'`` (normalised top-2 gap) or
            ``'random'``.
        mask_index: Boolean (batch, position) tensor, true where still masked.
        x: Current tokens; kept wherever ``mask_index`` is false.
        num_transfer_tokens: Per-row number of positions to commit. Ignored
            when ``threshold`` is given.
        threshold: If not ``None``, every masked position is a candidate; the
            most confident one is always committed and the others are
            committed only when their confidence is not below ``threshold``.
        neg_entropy: If true, rank positions by negative entropy of the
            softmax distribution instead of the ``remasking`` score.

    Returns:
        ``(x0, transfer_index)``: the proposed tokens (equal to ``x`` on
        unmasked positions) and a boolean tensor of positions to commit.

    Raises:
        NotImplementedError: If ``remasking`` is not a known strategy.
    """
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)  # (B, L)

    # The strategy score is always computed, even when neg_entropy overrides
    # it, so that an unknown strategy still raises and 'random' still draws.
    if remasking == 'low_confidence':
        p = F.softmax(logits, dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # (B, L)
    elif remasking == 'top_p_margin':
        x0_p = _margin_confidence(logits, mask_index)  # in [0, 1] on masked positions
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    if neg_entropy:
        p = F.softmax(logits, dim=-1)
        epsilon = 1e-10  # keeps log() finite for zero probabilities
        log_probs = torch.log(p + epsilon)
        confidence_scores = torch.sum(p * log_probs, dim=-1)  # negative entropy per position
    else:
        confidence_scores = x0_p

    x0 = torch.where(mask_index, x0, x)
    # Unmasked positions get -inf so top-k ranks them last.
    confidence = torch.where(mask_index, confidence_scores, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    if threshold is not None:
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)

    for j in range(confidence.shape[0]):
        k = int(num_transfer_tokens[j])
        _, select_index = torch.topk(confidence[j], k=k)
        transfer_index[j, select_index] = True
        if threshold is not None and k > 1:
            # select_index[0] is always kept; drop the rest when below threshold.
            rest = select_index[1:]
            below = confidence[j, rest] < threshold
            transfer_index[j, rest[below]] = False
    return x0, transfer_index


def get_num_transfer_tokens(mask_index, steps: int):
    """Split each row's masked-token count as evenly as possible over ``steps``.

    Args:
        mask_index: (batch, position) tensor, true where masked.
        steps: Number of denoising steps.

    Returns:
        A (batch, steps) tensor whose rows sum to the row's mask count; the
        first ``count % steps`` steps receive one extra token.
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)  # (B, 1)
    base = mask_num // steps
    remainder = mask_num % steps
    step_ids = torch.arange(steps, device=mask_index.device)  # (steps,)
    # Broadcasts to (B, steps): +1 for the first `remainder` steps of each row.
    return base + (step_ids < remainder).to(torch.int64)


@torch.no_grad()
def generate_with_prefix_cache_block_diff(
    model,
    prompt,
    steps=128,
    gen_length=128,
    block_length=128,
    temperature=0.,
    remasking='low_confidence',
    mask_id=126336,
    threshold=None,
    factor=None,
    shift_logits=False,
    neg_entropy=False,
):
    """Generate ``gen_length`` tokens block by block, reusing a prefix KV cache.

    Each block starts as ``block_length`` mask tokens appended to the sequence
    and is denoised for up to ``steps // num_blocks`` steps, feeding the model
    only the current block plus the cached context. After a block finishes,
    it is run once more with ``use_cache=True`` to extend the cache.

    Args:
        model: Callable as ``model(ids, past_key_values=..., use_cache=...)``
            returning an object with ``.logits`` and ``.past_key_values``.
        prompt: (batch, prompt_len) token tensor.
        steps: Total denoising steps; must be divisible by the block count.
        gen_length: Tokens to generate; must be divisible by ``block_length``.
        block_length: Tokens per block.
        temperature: Gumbel sampling temperature (0 means greedy argmax).
        remasking: Confidence strategy forwarded to ``get_transfer_index``.
        mask_id: Token id used for masked positions.
        threshold: Optional confidence threshold forwarded to
            ``get_transfer_index``.
        factor: Accepted for interface compatibility; not used here.
        shift_logits: If true, position ``t`` is predicted from the logits at
            ``t - 1``, with the first position using the last context logit.
        neg_entropy: Forwarded to ``get_transfer_index``.

    Returns:
        ``(x_accum, nfe)``: prompt plus generated tokens, and the number of
        model calls made after the initial prompt pass (denoising steps plus
        one cache refresh per block, including the last block).
    """
    dream_style = shift_logits

    x_accum = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    nfe = 0

    # Compute the KV cache for the prompt once (not counted in nfe).
    output = model(prompt, use_cache=True)
    past_key_values = output.past_key_values

    # For dream_style: the "next token" logit produced by the context so far.
    next_logits_context = output.logits[:, -1:, :] if dream_style else None  # (B, 1, V)

    for num_block in range(num_blocks):
        # Append a fresh block of mask tokens (no seeding).
        mask_block = torch.full(
            (prompt.shape[0], block_length),
            mask_id,
            dtype=prompt.dtype,
            device=prompt.device,
        )
        x_accum = torch.cat([x_accum, mask_block], dim=1)

        current_block_start = prompt.size(1) + num_block * block_length
        block_slice = slice(current_block_start, current_block_start + block_length)

        # The schedule covers every masked position of the block in both modes.
        schedule_mask = (x_accum[:, block_slice] == mask_id)  # (B, Lb)
        num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block)  # (B, steps)

        for i in range(steps_per_block):
            block_tokens = x_accum[:, block_slice]  # (B, Lb)
            mask_block_idx = (block_tokens == mask_id)  # (B, Lb)
            if mask_block_idx.sum() == 0:
                break  # block finished early (possible when threshold is set)

            nfe += 1
            # Forward only the current noisy block using the cached context.
            # NOTE(review): assumes the model leaves past_key_values untouched
            # when use_cache=False - confirm for the cache type in use.
            logits_block = model(
                block_tokens,
                past_key_values=past_key_values,
                use_cache=False,
            ).logits

            if not dream_style:
                # Same-position case: logits at t predict token t.
                logits_use = logits_block
            elif block_length == 1:
                logits_use = next_logits_context  # (B, 1, V)
            else:
                # Shift right by one: prepend the context-next logit and drop
                # the last block logit so every position has a predictor.
                logits_use = torch.cat(
                    [next_logits_context, logits_block[:, :-1, :]], dim=1
                )  # (B, Lb, V)

            x0, transfer_idx = get_transfer_index(
                logits_use, temperature, remasking, mask_block_idx, block_tokens,
                num_transfer_tokens=num_transfer_tokens[:, i],
                threshold=threshold,
                neg_entropy=neg_entropy,
            )
            cur = x_accum[:, block_slice].clone()
            cur[transfer_idx] = x0[transfer_idx]
            x_accum[:, block_slice] = cur

        # After the block is denoised, extend the KV cache with it.
        output = model(
            x_accum[:, block_slice],
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = output.past_key_values
        nfe += 1

        if dream_style and num_block < num_blocks - 1:
            # Refresh the context-next logit for the next block.
            next_logits_context = output.logits[:, -1:, :]  # (B, 1, V)

    return x_accum, nfe