import numpy as np
import torch
import torch.nn.functional as F


def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel-max trick is a method for sampling from categorical
    distributions. According to arXiv:2409.02908, for MDMs, low-precision
    Gumbel-max improves perplexity scores but reduces generation quality,
    so float64 is used here.
    '''
    if temperature == 0:
        # Zero temperature degenerates to greedy (argmax) decoding.
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    # exp(logits) / (-log U) ** T shares its argmax with the standard
    # Gumbel-max form logits / T - log(-log U), so argmax over this
    # expression samples from the temperature-scaled distribution.
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
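

# A minimal self-check sketch (added for illustration, not part of the
# generation path): taking logs of exp(logits) / (-log U) ** T gives
# logits - T * log(-log U), and dividing by T > 0 preserves the argmax,
# recovering the usual Gumbel-max form logits / T + Gumbel noise.
def _check_gumbel_equivalence(temperature=0.7, seed=0):
    torch.manual_seed(seed)
    logits = torch.randn(4, 10, dtype=torch.float64)
    noise = torch.rand_like(logits)
    lhs = (logits.exp() / (- torch.log(noise)) ** temperature).argmax(dim=-1)
    rhs = (logits / temperature - torch.log(- torch.log(noise))).argmax(dim=-1)
    assert torch.equal(lhs, rhs)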


def get_transfer_index(logits, temperature, remasking, mask_index, x,
                       num_transfer_tokens, threshold=None, neg_entropy=False):
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)

    if remasking == 'low_confidence':
        # Confidence = probability assigned to the sampled token.
        p = F.softmax(logits, dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
    elif remasking == 'top_p_margin':
        # Confidence = margin between the top-2 probabilities.
        p = F.softmax(logits, dim=-1)
        top2 = torch.topk(p, k=2, dim=-1).values
        margin = top2[..., 0] - top2[..., 1]

        # Per-row min/max over the masked positions only.
        plus_inf = torch.full_like(margin, float('inf'))
        minus_inf = torch.full_like(margin, float('-inf'))
        masked_for_min = torch.where(mask_index, margin, plus_inf)
        masked_for_max = torch.where(mask_index, margin, minus_inf)
        row_min = masked_for_min.amin(dim=1, keepdim=True)
        row_max = masked_for_max.amax(dim=1, keepdim=True)
        denom = row_max - row_min

        # Min-max normalize the margins; rows whose masked margins are all
        # equal (zero range) get confidence 1 everywhere.
        normalized = torch.zeros_like(margin)
        nonzero = denom > 0
        normalized = torch.where(
            mask_index & nonzero,
            (margin - row_min) / (denom + 1e-12),
            normalized
        )
        normalized = torch.where(
            mask_index & (~nonzero),
            torch.ones_like(normalized),
            normalized
        )
        x0_p = normalized
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    if neg_entropy:
        # Confidence = negative entropy of the full distribution; a more
        # peaked (more confident) distribution scores higher.
        p = F.softmax(logits, dim=-1)
        epsilon = 1e-10
        log_probs = torch.log(p + epsilon)
        confidence_scores = torch.sum(p * log_probs, dim=-1)
    else:
        confidence_scores = x0_p

    # Keep already-decoded tokens; only masked positions remain candidates.
    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, confidence_scores, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    if threshold is not None:
        # With a threshold, rank all remaining masked positions.
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
    for j in range(confidence.shape[0]):
        _, select_index = torch.topk(confidence[j], k=int(num_transfer_tokens[j]))
        transfer_index[j, select_index] = True
        if threshold is not None:
            # Always commit the top-1 token; drop any other selected token
            # whose confidence falls below the threshold.
            for k in range(1, int(num_transfer_tokens[j])):
                if confidence[j, select_index[k]] < threshold:
                    transfer_index[j, select_index[k]] = False
    return x0, transfer_index
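

# Usage sketch (added for illustration; `vocab_size` and the tensor shapes
# are arbitrary): one denoising step over a fully masked sequence,
# committing the two highest-confidence predictions.
def _transfer_step_example(vocab_size=32, mask_id=126336):
    logits = torch.randn(1, 5, vocab_size)
    x = torch.full((1, 5), mask_id)
    mask = (x == mask_id)
    x0, idx = get_transfer_index(
        logits, temperature=0., remasking='low_confidence',
        mask_index=mask, x=x, num_transfer_tokens=torch.tensor([2]))
    x[idx] = x0[idx]  # commit the selected tokens; the rest stay masked
    return x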


def get_num_transfer_tokens(mask_index, steps: int):
    '''
    Split each row's masked-token budget evenly across `steps` denoising
    steps; the first `remainder` steps each get one extra token.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(
        mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64
    ) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, : int(remainder[i])] += 1
    return num_transfer_tokens
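

# Worked example (added for illustration): with 10 masked tokens and
# steps=4, base = 10 // 4 = 2 and remainder = 10 % 4 = 2, so the first two
# steps each get one extra token:
#
#   >>> get_num_transfer_tokens(torch.ones(1, 10, dtype=torch.bool), steps=4)
#   tensor([[3, 3, 2, 2]])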


@torch.no_grad()
def generate_with_prefix_cache_block_diff(
    model,
    prompt,
    steps=128,
    gen_length=128,
    block_length=128,
    temperature=0.,
    remasking='low_confidence',
    mask_id=126336,
    threshold=None,
    factor=None,  # unused; kept for interface compatibility
    shift_logits=False,
    neg_entropy=False,
):
    dream_style = shift_logits

    x_accum = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    nfe = 0

    # Prefill: run the prompt once and cache its key/value states so later
    # forward passes only need to process the current block.
    output = model(prompt, use_cache=True)
    past_key_values = output.past_key_values

    # Dream-style (shifted) decoding: position t is predicted by the logits
    # at position t - 1, so the prompt's last logit predicts the first token
    # of the next block.
    next_logits_context = None
    if dream_style:
        next_logits_context = output.logits[:, -1:, :]
    for num_block in range(num_blocks):
        # Append a fully masked block to the running sequence.
        mask_block = torch.ones(
            (prompt.shape[0], block_length),
            dtype=prompt.dtype,
            device=prompt.device
        ) * mask_id

        x_accum = torch.cat([x_accum, mask_block], dim=1)
        current_block_start = prompt.size(1) + num_block * block_length
        block_slice = slice(current_block_start, current_block_start + block_length)

        # Schedule the per-step token budget from the block's initial masks
        # (the schedule is the same for dream-style and standard decoding).
        schedule_mask = (x_accum[:, block_slice] == mask_id)

        num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block)
        for i in range(steps_per_block):
            mask_block_idx = (x_accum[:, block_slice] == mask_id)
            if mask_block_idx.sum() == 0:
                # The block finished early (e.g. with threshold decoding).
                break

            nfe += 1

            # Forward pass over the current block only, attending to the
            # cached prefix; the cache itself is not updated here.
            logits_block = model(
                x_accum[:, block_slice],
                past_key_values=past_key_values,
                use_cache=False
            ).logits

            if dream_style:
                # Shift logits right by one position, prepending the last
                # logit of the cached prefix.
                if block_length == 1:
                    logits_use = next_logits_context
                else:
                    logits_use = torch.cat(
                        [next_logits_context, logits_block[:, :-1, :]],
                        dim=1
                    )

                mask_use = mask_block_idx
                x_use = x_accum[:, block_slice]

                x0, transfer_idx = get_transfer_index(
                    logits_use, temperature, remasking, mask_use, x_use,
                    num_transfer_tokens=num_transfer_tokens[:, i],
                    threshold=threshold, neg_entropy=neg_entropy
                )
                cur = x_accum[:, block_slice].clone()
                cur[transfer_idx] = x0[transfer_idx]
                x_accum[:, block_slice] = cur
            else:
                x0, transfer_idx = get_transfer_index(
                    logits_block, temperature, remasking, mask_block_idx,
                    x_accum[:, block_slice],
                    num_transfer_tokens=num_transfer_tokens[:, i],
                    threshold=threshold, neg_entropy=neg_entropy
                )
                cur = x_accum[:, block_slice].clone()
                cur[transfer_idx] = x0[transfer_idx]
                x_accum[:, block_slice] = cur
        # The block is fully decoded: run it once more with use_cache=True to
        # append its key/value states to the prefix cache.
        output = model(
            x_accum[:, block_slice],
            past_key_values=past_key_values,
            use_cache=True
        )
        past_key_values = output.past_key_values
        nfe += 1

        if dream_style and num_block < num_blocks - 1:
            # Carry the last logit forward as the shifted context for the
            # first position of the next block.
            next_logits_context = output.logits[:, -1:, :]

    return x_accum, nfe
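

# Usage sketch (hypothetical setup; adjust the checkpoint and prompt to your
# environment, and note the model must return past_key_values for the prefix
# cache above to work). The default mask_id=126336 matches LLaDA's mask token.
#
#   from transformers import AutoModel, AutoTokenizer
#   model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct',
#                                     trust_remote_code=True).cuda().eval()
#   tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct',
#                                             trust_remote_code=True)
#   prompt = tokenizer('Explain diffusion language models.',
#                      return_tensors='pt')['input_ids'].cuda()
#   out, nfe = generate_with_prefix_cache_block_diff(
#       model, prompt, steps=128, gen_length=128, block_length=32,
#       temperature=0., remasking='low_confidence')
#   print(tokenizer.batch_decode(out[:, prompt.shape[1]:]))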