import numpy as np
import torch
import torch.nn.functional as F


def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that an argmax over the result samples a token.

    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves
    perplexity score but reduces generation quality. Thus, we use float64.

    Args:
        logits: Tensor of unnormalized scores.
        temperature: Sampling temperature. ``0`` disables the noise.

    Returns:
        ``logits`` itself (same object, same dtype) when ``temperature == 0``;
        otherwise a float64 tensor ``exp(logits) / (-log(u)) ** temperature``
        with ``u`` drawn from ``torch.rand_like``.
    """
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def _top2_margin_confidence(probs, mask_index):
    """Return the top-1/top-2 probability margin, min-max scaled per row.

    The scaling uses only positions where ``mask_index`` is True. Masked
    positions of a row whose margins are all equal get 1; unmasked positions
    get 0.
    """
    top2 = torch.topk(probs, k=2, dim=-1).values
    margin = top2[..., 0] - top2[..., 1]

    # Row-wise extrema restricted to masked positions.
    row_min = margin.masked_fill(~mask_index, float('inf')).amin(dim=1, keepdim=True)
    row_max = margin.masked_fill(~mask_index, float('-inf')).amax(dim=1, keepdim=True)
    span = row_max - row_min

    scaled = (margin - row_min) / (span + 1e-12)
    normalized = torch.where(span > 0, scaled, torch.ones_like(margin))
    return torch.where(mask_index, normalized, torch.zeros_like(margin))


def get_transfer_index(logits, temperature, remasking, mask_index, x,
                       num_transfer_tokens, threshold=None, neg_entropy=False):
    """Pick predicted tokens and decide which positions to commit this step.

    Args:
        logits: Model scores; the original annotations describe them as
            (B, L, V).
        temperature: Passed to ``add_gumbel_noise`` before the argmax.
        remasking: Confidence rule: ``'low_confidence'`` (softmax probability
            of the chosen token), ``'top_p_margin'`` (normalized top-2 margin)
            or ``'random'`` (uniform noise).
        mask_index: Boolean tensor marking the positions that are still masked.
        x: Current tokens; kept wherever ``mask_index`` is False.
        num_transfer_tokens: Per-row number of positions to commit. Ignored
            (replaced by the per-row mask count) when ``threshold`` is given.
        threshold: If not None, every masked position whose confidence is at
            least ``threshold`` is committed, and the single most confident
            position of each row is committed regardless.
        neg_entropy: If True, rank positions by negative entropy of the
            softmax distribution instead of the ``remasking`` score.

    Returns:
        ``(x0, transfer_index)``: the candidate tokens (equal to ``x`` outside
        the mask) and a boolean tensor of positions selected for commitment.

    Raises:
        NotImplementedError: If ``remasking`` is not a supported rule.
    """
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)

    if remasking == 'low_confidence':
        p = F.softmax(logits, dim=-1)
        x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == 'top_p_margin':
        x0_p = _top2_margin_confidence(F.softmax(logits, dim=-1), mask_index)
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    if neg_entropy:
        p = F.softmax(logits, dim=-1)
        # The epsilon keeps log() finite for zero probabilities.
        confidence_scores = torch.sum(p * torch.log(p + 1e-10), dim=-1)
    else:
        confidence_scores = x0_p

    x0 = torch.where(mask_index, x0, x)
    # Already-decoded positions must never win the top-k below.
    confidence = confidence_scores.masked_fill(~mask_index, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool)
    if threshold is not None:
        # Consider every masked position; the threshold does the pruning.
        num_transfer_tokens = mask_index.sum(dim=1)

    for row in range(confidence.shape[0]):
        k = int(num_transfer_tokens[row])
        _, select_index = torch.topk(confidence[row], k=k)
        transfer_index[row, select_index] = True
        if threshold is not None:
            below = confidence[row, select_index] < threshold
            # The best candidate is always kept so each step makes progress.
            below[:1] = False
            transfer_index[row, select_index[below]] = False
    return x0, transfer_index


def get_num_transfer_tokens(mask_index, steps: int):
    """Split each row's masked-token count as evenly as possible over ``steps``.

    Args:
        mask_index: Tensor whose sum along dim 1 is the number of masked
            positions per row.
        steps: Number of denoising steps.

    Returns:
        An int64 tensor of shape ``(rows, steps)``. Every entry of a row is
        ``count // steps``, and the first ``count % steps`` entries get one
        extra.
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    step_ids = torch.arange(steps, device=mask_index.device)
    return base + (step_ids < remainder).to(torch.int64)


def _set_diffusion_lm(model, enabled):
    """Set ``diffusion_lm`` on every encoder self-attention that has the flag."""
    model_module = model.module if hasattr(model, "module") else model
    for layer in model_module.encoder.layers:
        if hasattr(layer.self_attn, 'diffusion_lm'):
            layer.self_attn.diffusion_lm = enabled


def _forward_for_cache(model, tokens, causal_context, **model_kwargs):
    """Run a ``use_cache=True`` forward pass that produces context KV entries.

    When ``causal_context`` is truthy the ``diffusion_lm`` flags are switched
    off for the duration of the call and switched back on afterwards, even if
    the forward pass raises.
    """
    if not causal_context:
        return model(tokens, use_cache=True, use_causal_mask=causal_context,
                     **model_kwargs)
    _set_diffusion_lm(model, False)
    try:
        return model(tokens, use_cache=True, use_causal_mask=causal_context,
                     **model_kwargs)
    finally:
        _set_diffusion_lm(model, True)


@torch.no_grad()
def generate_with_prefix_cache_block_diff(
    model,
    prompt,
    steps=128,
    gen_length=128,
    block_length=128,
    temperature=0.,
    remasking='low_confidence',
    mask_id=126336,
    threshold=None,
    factor=None,
    shift_logits=False,
    neg_entropy=False,
    causal_context=False,
    pixel_values=None,
    image_sizes=None,
    eos_token_id=None,
):
    """Generate tokens block by block, denoising each block against a KV cache.

    The prompt is encoded once to build the cache. Each block is appended as
    ``block_length`` mask tokens, iteratively unmasked using only the block as
    model input, then encoded once more to extend the cache.

    Args:
        model: Callable returning an object with ``logits`` and
            ``past_key_values``. It is called with ``use_cache``,
            ``use_causal_mask``, ``past_key_values``, ``pixel_values`` and
            ``image_sizes`` keyword arguments.
        prompt: Integer token tensor indexed as (batch, length).
        steps: Total denoising steps; must be divisible by the block count.
        gen_length: Tokens to generate; must be divisible by ``block_length``.
        block_length: Number of tokens per block.
        temperature: Gumbel sampling temperature (0 means greedy).
        remasking: Confidence rule forwarded to ``get_transfer_index``.
        mask_id: Token id used for masked positions.
        threshold: Optional confidence threshold forwarded to
            ``get_transfer_index``.
        factor: Currently unused; accepted for interface compatibility.
        shift_logits: If True, position ``t`` is predicted from the logits at
            position ``t - 1``. The first position of a block uses the last
            logit of the preceding context.
        neg_entropy: Forwarded to ``get_transfer_index``.
        causal_context: If truthy, cache-building passes run with
            ``use_causal_mask`` set and ``diffusion_lm`` disabled on the
            attention layers.
        pixel_values: Passed to the model only with the prompt.
        image_sizes: Passed to the model only with the prompt.
        eos_token_id: If given, enables early stopping on end-of-sequence.

    Returns:
        ``(tokens, nfe)``. ``tokens`` is the prompt followed by every block
        generated so far; it is shorter than ``prompt + gen_length`` when all
        rows reached EOS early. ``nfe`` counts the denoising forward passes
        only, not the cache-building passes.
    """
    dream_style = shift_logits

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    x_accum = prompt.clone()
    prompt_len = prompt.size(1)
    nfe = 0

    # Compute the KV cache for the prompt. pixel_values / image_sizes are
    # passed only here (the prompt contains the image tokens).
    output = _forward_for_cache(
        model, prompt, causal_context,
        pixel_values=pixel_values, image_sizes=image_sizes,
    )
    past_key_values = output.past_key_values

    # For dream_style: the "next token" logit produced by the context.
    next_logits_context = output.logits[:, -1:, :] if dream_style else None

    for num_block in range(num_blocks):
        # Append a fresh, fully masked block.
        mask_block = torch.full(
            (prompt.shape[0], block_length), mask_id,
            dtype=prompt.dtype, device=prompt.device,
        )
        x_accum = torch.cat([x_accum, mask_block], dim=1)
        block_start = prompt_len + num_block * block_length
        block_slice = slice(block_start, block_start + block_length)

        # Transfer schedule for this block, shape (B, steps_per_block).
        num_transfer_tokens = get_num_transfer_tokens(
            x_accum[:, block_slice] == mask_id, steps_per_block
        )

        for step in range(steps_per_block):
            block_tokens = x_accum[:, block_slice]
            mask_block_idx = block_tokens == mask_id
            if not mask_block_idx.any():
                break

            nfe += 1
            # Forward only the current noisy block against the cached context.
            logits_block = model(
                block_tokens,
                past_key_values=past_key_values,
                use_cache=False,
            ).logits

            if not dream_style:
                logits_use = logits_block
            elif block_length == 1:
                logits_use = next_logits_context
            else:
                # Shift right by one: prepend the context's next-token logit.
                logits_use = torch.cat(
                    [next_logits_context, logits_block[:, :-1, :]], dim=1
                )

            x0, transfer_idx = get_transfer_index(
                logits_use, temperature, remasking, mask_block_idx, block_tokens,
                num_transfer_tokens=num_transfer_tokens[:, step],
                threshold=threshold, neg_entropy=neg_entropy,
            )
            x_accum[:, block_slice] = torch.where(transfer_idx, x0, block_tokens)

            if eos_token_id is not None:
                block_tokens = x_accum[:, block_slice]
                at_or_after_eos = (block_tokens == eos_token_id).cumsum(dim=1).bool()
                # A row still needs denoising while it has a mask that is not
                # preceded by an EOS. Stop only when no row needs it, so one
                # finished sequence cannot leave masks in the rest of the batch.
                unresolved = ((block_tokens == mask_id) & ~at_or_after_eos).any(dim=1)
                if not unresolved.any():
                    break

        # After the block is denoised, extend the KV cache with it.
        output = _forward_for_cache(
            model, x_accum[:, block_slice], causal_context,
            past_key_values=past_key_values,
        )
        past_key_values = output.past_key_values

        if dream_style and num_block < num_blocks - 1:
            # Refresh the context's next-token logit for the following block.
            next_logits_context = output.logits[:, -1:, :]

        if eos_token_id is not None:
            generated = x_accum[:, prompt_len:]
            if (generated == eos_token_id).any(dim=1).all():
                # Every sequence has produced an EOS; skip remaining blocks.
                return x_accum, nfe

    return x_accum, nfe