drizzlezyk commited on
Commit
71897e7
·
verified ·
1 Parent(s): c708886

Upload inference/generation_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference/generation_utils.py +313 -0
inference/generation_utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # SPDX-License-Identifier: Apache-2.0
17
+ # Modified from Dream repos: https://github.com/HKUNLP/Dream
18
+
19
+
20
+
21
+ from dataclasses import dataclass
22
+ from collections.abc import Iterable
23
+ from typing import Any, Dict, Optional, Tuple, Union
24
+
25
+ import torch
26
+ try:
27
+ import torch_npu
28
+ except ImportError as e:
29
+ pass
30
+ import torch.distributions as dists
31
+ from torch.nn import functional as F
32
+
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+
35
+
36
def top_p_logits(logits, top_p=None):
    """Apply nucleus (top-p) filtering to `logits` along the last dimension.

    Tokens outside the smallest set whose cumulative probability exceeds
    `top_p` have their logits set to the dtype's minimum, so a subsequent
    softmax assigns them (effectively) zero probability.
    """
    ranked_logits, ranked_idx = torch.sort(logits, descending=True)
    cum_probs = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
    drop = cum_probs > top_p
    # Shift right so the first token that crosses the threshold is kept.
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False

    # Scatter the sorted-order drop flags back into vocabulary order.
    remove_mask = torch.zeros_like(logits, dtype=torch.bool)
    remove_mask.scatter_(-1, ranked_idx, drop)
    return logits.masked_fill(remove_mask, torch.finfo(logits.dtype).min)
48
+
49
+
50
def top_k_logits(logits, top_k=None):
    """Keep only the `top_k` highest logits along the last dimension.

    All other positions are filled with the dtype's minimum value so they
    vanish after softmax. `top_k` is clamped to the vocabulary size.
    """
    k = min(top_k, logits.size(-1))  # never ask for more than exist
    kth_value = torch.topk(logits, k)[0][..., -1, None]
    below_kth = logits < kth_value
    return logits.masked_fill(below_kth, torch.finfo(logits.dtype).min)
56
+
57
+
58
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Pick one token per position from `logits` and score its confidence.

    Pipeline: optional temperature scaling, then top-p and top-k filtering,
    then either categorical sampling (temperature > 0) or greedy argmax.

    Parameters:
        logits (`torch.Tensor`): unnormalized scores, vocabulary on the last dim.
        temperature (`float`): > 0 enables stochastic sampling; 0 means greedy.
        top_p (`float`, optional): nucleus filtering threshold in (0, 1).
        top_k (`int`, optional): keep only the k highest logits.
        margin_confidence (`bool`): if True, confidence = top1 prob - top2 prob.
        neg_entropy (`bool`): if True, confidence = sum(p * log p) (negative entropy).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (confidence, sampled token ids),
        both shaped like `logits` without its last dimension.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # BUGFIX: was a bare `except:` which also swallowed SystemExit /
            # KeyboardInterrupt. Fall back to greedy if sampling fails
            # (e.g. invalid probabilities).
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Confidence is the top1-top2 probability margin.
        # BUGFIX: use `...` indexing instead of `[:, 0]` so this is correct
        # for both 2-D [B, V] and 3-D [B, L, V] logits; the old code picked
        # a whole sequence slice on 3-D input.
        confidence = sorted_probs[..., 0] - sorted_probs[..., 1]

    if neg_entropy:
        epsilon = 1e-10  # guard log(0)
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
91
+
92
+
93
class BlockDynamicCache(DynamicCache):
    """A `DynamicCache` that can serve reads without persisting new states.

    While `skip_cache_update` is True, `update` leaves the stored key/value
    tensors untouched and instead returns the cached states concatenated with
    the incoming ones along the sequence dimension.

    Example:

    ```python
    >>> past_key_values = BlockDynamicCache()
    >>> past_key_values.skip_cache_update = True
    >>> outputs.past_key_values
    ```
    """

    def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None:
        """Initialize the cache; `skip_cache_update` defaults to False."""
        super().__init__(_distributed_cache_data)
        # When True, update() becomes read-only with respect to stored state.
        self.skip_cache_update = False

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return key/value states for layer `layer_idx`, optionally caching them.

        Behavior depends on the `skip_cache_update` flag:
        - False: delegate to `DynamicCache.update`, which stores the new
          states and returns the updated cache contents.
        - True: do NOT modify the stored cache; return the cached states
          concatenated with `key_states` / `value_states` along the
          sequence dimension (dim=-2).

        Parameters:
            key_states (`torch.Tensor`): new key states for this layer.
            value_states (`torch.Tensor`): new value states for this layer.
            layer_idx (`int`): index of the layer being updated.
            cache_kwargs (`dict[str, Any]`, optional): extra arguments
                forwarded to the parent class; unused by `DynamicCache`.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: the key and value states to
            use for attention — concatenated (read-only mode) or updated
            (normal mode).
        """
        if not self.skip_cache_update:
            # Normal path: parent class appends to the stored cache.
            return super().update(key_states, value_states, layer_idx, cache_kwargs)
        # Read-only path: concatenate along the sequence axis without
        # mutating the stored tensors.
        merged_keys = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
        merged_values = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return merged_keys, merged_values
156
+
157
+
158
+
159
@torch.no_grad()
def diffusion_generate(
    model,
    inputs: Optional[torch.Tensor] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    threshold: Optional[float] = 0.9,
    num_small_blocks: Optional[int] = 1,
    **kwargs,
):
    """Block-wise diffusion decoding with a block-level KV cache.

    The sequence after the prompt is split into `num_blocks` blocks of
    `block_length` tokens, each initialized to `mask_token_id`. Within a
    block, tokens are revealed iteratively: at each step the model scores
    the remaining masked positions and the most confident one(s) are
    committed. Once a block is fully decoded, its key/value states are
    stored in the cache, and only then does decoding move to the next block.

    Parameters:
        model: a causal LM callable returning `.logits` and `.past_key_values`.
        inputs (`torch.Tensor`): prompt token ids, shape [B, prompt_length].
        top_p (`float`, optional): nucleus sampling threshold.
        top_k (`int`, optional): top-k sampling cutoff.
        threshold (`float`, optional): confidence threshold used when
            `alg == 'confidence_threshold'`.
        num_small_blocks (`int`, optional): number of sub-blocks decoded
            sequentially inside each block; must divide `block_length`.
        **kwargs: `block_length`, `attention_mask`, `alg`, `temperature`,
            `mask_token_id`, `eos_token_id`, and `max_new_tokens` or
            `max_length`.

    Returns:
        torch.Tensor: the full token sequence [B, max_length], prompt included.

    Raises:
        ValueError: on missing required kwargs or incompatible length/block
            settings.
    """
    block_length = kwargs.pop("block_length", 32)
    attention_mask = kwargs.pop("attention_mask", None)
    alg = kwargs.get("alg", 'origin')
    temperature = kwargs.get("temperature", 0.0)
    mask_token_id = kwargs.get("mask_token_id", None)
    eos_token_id = kwargs.get("eos_token_id", None)

    if mask_token_id is None:
        raise ValueError("mask_token_id must be provided")

    if eos_token_id is None:
        raise ValueError("eos_token_id must be provided")

    if inputs is None:
        raise ValueError("inputs must be provided")

    if attention_mask is None:
        raise ValueError("attention_mask must be provided")

    input_ids = inputs

    if type(kwargs.get('max_new_tokens', None)) is int:
        max_length = kwargs.get('max_new_tokens') + input_ids.shape[-1]
    elif kwargs.get('max_length', None) is None:
        raise ValueError("Pass max_new_tokens or max_length")
    else:
        # BUGFIX: `max_length` was never assigned on this path, so passing
        # only `max_length` (without `max_new_tokens`) raised NameError below.
        max_length = kwargs.get('max_length')

    prompt_length = input_ids.shape[1]
    if (max_length - prompt_length) % block_length != 0:
        raise ValueError(
            f"The token length ({max_length - prompt_length}) "
            f"cannot be evenly divided by the block length ({block_length})."
        )

    num_blocks = (max_length - prompt_length) // block_length
    device = model.device
    position_ids = torch.arange(max_length, device=device).unsqueeze(0)
    # Pad input_ids to max_length; generated positions start as mask tokens.
    x = F.pad(input_ids, (0, max_length - prompt_length), value=mask_token_id)

    # Initialize cache for the prompt.
    past_key_values = BlockDynamicCache()

    causal_mask = torch.tril(torch.ones(max_length, max_length, device=device, dtype=torch.bool))[None, None, :, :]

    # Extend the padding mask over generated positions (all valid -> 1),
    # and derive position ids that skip padded prompt tokens.
    padding_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
    position_ids = padding_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(padding_mask == 0, 1)
    # [B, N] --> [B, 1, N, N]: pairwise validity of (query, key) positions.
    padding_mask = torch.logical_and(
        padding_mask.unsqueeze(1).unsqueeze(-2),
        padding_mask.unsqueeze(1).unsqueeze(-1),
    )
    attention_mask = padding_mask & causal_mask

    # Prefill stage: run the prompt once to populate the KV cache, and use
    # the last-position logits to commit the very first generated token.
    if prompt_length > 0:
        cur_x = x[:, :prompt_length]
        cur_attn_mask = attention_mask[:, :, :prompt_length, :prompt_length]
        cur_position_ids = position_ids[:, :prompt_length]
        output = model(cur_x,
                       attention_mask=cur_attn_mask,
                       position_ids=cur_position_ids,
                       past_key_values=past_key_values,
                       use_cache=True
                       )
        past_key_values = output.past_key_values

        logits = output.logits[:, -1:]
        confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
        x[:, prompt_length:prompt_length + 1] = x0

    # Process each block.
    for num_block in range(num_blocks):
        block_start = prompt_length + num_block * block_length
        block_end = prompt_length + (num_block + 1) * block_length
        cur_x = x[:, block_start:block_end]
        cur_attn_mask = attention_mask[:, :, block_start:block_end, :block_end]
        cur_padding_mask = padding_mask[:, :, block_start:block_end, :block_end]
        cur_position_ids = position_ids[:, block_start:block_end]
        # Use cache for generation.
        small_block_length = block_length // num_small_blocks

        if block_length % num_small_blocks != 0:
            raise ValueError(
                f"block_length ({block_length}) must be divisible by num_small_blocks ({num_small_blocks})."
            )

        # Just concatenate current key/value states; do not update the KV cache
        # while the block is still being iteratively refined.
        past_key_values.skip_cache_update = True
        for small_block_idx in range(num_small_blocks):
            small_block_start = small_block_idx * small_block_length
            small_block_end = small_block_start + small_block_length

            while True:
                sub_mask_index = (cur_x[:, small_block_start:small_block_end] == mask_token_id)
                if sub_mask_index.sum() == 0:
                    break

                output = model(cur_x,
                               attention_mask=cur_padding_mask,
                               position_ids=cur_position_ids,
                               past_key_values=past_key_values,
                               use_cache=True)
                logits = output.logits
                # Shift logits right by one position so each position is
                # predicted from the previous one (diffusion-style alignment).
                logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
                logits = logits[:, small_block_start:small_block_end]

                confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k,
                                               neg_entropy=(alg == 'entropy'), margin_confidence=(alg == 'topk_margin'))
                # Only still-masked positions compete for transfer.
                confidence = torch.where(sub_mask_index, confidence, -torch.inf)
                # Always commit at least the single most confident position.
                transfer_index = (F.one_hot(torch.max(confidence, dim=1)[1], num_classes=small_block_length) == 1)
                if alg == 'confidence_threshold':
                    transfer_index |= (confidence > threshold)
                cur_x[:, small_block_start:small_block_end][transfer_index] = x0[transfer_index]

                # BUGFIX: was `if eos_token_id and ...`, which disabled early
                # stopping when the eos token id is 0. eos_token_id is already
                # validated non-None above.
                if (x[:, prompt_length:] == eos_token_id).any(dim=1).all():
                    return x

        # Block finished: run it once more with cache updates enabled so its
        # key/value states are stored for subsequent blocks.
        past_key_values.skip_cache_update = False
        output = model(cur_x,
                       attention_mask=cur_attn_mask,
                       position_ids=cur_position_ids,
                       past_key_values=past_key_values,
                       use_cache=True,
                       )
        past_key_values = output.past_key_values
        if num_block < num_blocks - 1:
            # Seed the first token of the next block from the last logits.
            logits = output.logits[:, -1:]
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
            x[:, block_end:block_end + 1] = x0

    return x
304
+
305
+
306
+
307
+
308
+
309
+
310
+
311
+
312
+
313
+