klemenk committed on
Commit
ee0c1d2
·
verified ·
1 Parent(s): f47bc71

Upload AuriStream base model code

Browse files
configuration_auristream.py CHANGED
@@ -50,9 +50,6 @@ class AuriStreamConfig(PretrainedConfig):
50
  input_conv_kernel_size: int = 0,
51
  **kwargs,
52
  ):
53
- # Force no weight tying
54
- kwargs["tie_word_embeddings"] = False
55
-
56
  self.vocab_size = vocab_size
57
  self.n_embd = n_embd
58
  self.n_layer = n_layer
 
50
  input_conv_kernel_size: int = 0,
51
  **kwargs,
52
  ):
 
 
 
53
  self.vocab_size = vocab_size
54
  self.n_embd = n_embd
55
  self.n_layer = n_layer
modeling_auristream.py CHANGED
@@ -8,14 +8,22 @@ https://huggingface.co/TuKoResearch/WavCochCausalV8192
8
  """
9
 
10
  import math
11
- from typing import Optional, List
 
 
 
12
 
13
  import torch
14
  import torch.nn as nn
15
  from torch.nn import functional as F
16
 
17
  from transformers import PreTrainedModel
18
- from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
 
 
 
 
 
19
 
20
  from .configuration_auristream import AuriStreamConfig
21
 
@@ -67,7 +75,7 @@ def apply_rotary_emb(x, cos, sin):
67
  x2 = x[..., d:]
68
  y1 = x1 * cos + x2 * sin
69
  y2 = x1 * (-sin) + x2 * cos
70
- return torch.cat([y1, y2], dim=3)
71
 
72
 
73
  class CausalSelfAttention(nn.Module):
@@ -225,6 +233,11 @@ class AuriStreamPreTrainedModel(PreTrainedModel):
225
  base_model_prefix = "model"
226
  supports_gradient_checkpointing = True
227
  _no_split_modules = ["Block"]
 
 
 
 
 
228
 
229
  def _init_weights(self, module):
230
  if isinstance(module, nn.Linear):
@@ -240,8 +253,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
240
  AuriStream speech language model.
241
 
242
  A GPT-like transformer model for cochlear token prediction with optional
243
- multi-token prediction (MTP) heads for improved representation learning and
244
- novel inference capabilities.
245
 
246
  Developed by Greta Tuckute and Klemen Kotar.
247
  """
@@ -267,11 +279,10 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
267
  else:
268
  self.future_heads = None
269
 
270
- # "Standard" LM output head
271
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
272
 
273
- # Initialize weights
274
- self.apply(self._init_weights)
275
  # Apply special scaled init to residual projections
276
  for pn, p in self.named_parameters():
277
  if pn.endswith('c_proj.weight'):
@@ -291,11 +302,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
291
  self,
292
  input_ids: Optional[torch.LongTensor] = None,
293
  labels: Optional[torch.LongTensor] = None,
294
- output_logits: Optional[bool] = False,
295
  output_hidden_states: Optional[bool] = False,
296
  return_dict: Optional[bool] = True,
297
- up_until_layer: Optional[int] = None,
298
- normalize_embeddings: Optional[str] = None,
299
  # Legacy arguments for compatibility
300
  seq: Optional[torch.LongTensor] = None,
301
  tgt: Optional[torch.LongTensor] = None,
@@ -306,27 +314,13 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
306
  Args:
307
  input_ids: Input token IDs of shape (batch_size, seq_len)
308
  labels: Target token IDs for computing loss
309
- output_logits: Whether to return all logits (including from future heads).
310
- The first element corresponds to the standard next-token head (prediction of i+1);
311
- subsequent elements correspond to future heads predicting tokens i+2, i+3, etc.
312
- output_hidden_states: Whether to return all hidden states, including the input
313
- embedding state and final pre-ln_f state. Matches HuggingFace GPT-style.
314
- return_dict: Whether to return a dict or tuple. If True, return a CausalLMOutput dict,
315
- otherwise return a tuple.
316
- up_until_layer: If set, stop the forward pass after this transformer block
317
- (inclusive) and return intermediate activations. Useful for saving compute.
318
- normalize_embeddings: 'l2' or 'learned' to normalize hidden states
319
- seq: Legacy argument (alias for input_ids for backward compatibility)
320
- tgt: Legacy argument (alias for labels for backward compatibility)
321
 
322
  Returns:
323
- If return_dict is True:
324
- CausalLMOutput with fields:
325
- • loss (optional): Scalar training loss
326
- • logits: Tensor or list of tensors of prediction logits
327
- • hidden_states (optional): Tuple of hidden states
328
- Otherwise:
329
- Tuple of (logits or list of logits, loss).
330
  """
331
  # Handle legacy arguments
332
  if seq is not None:
@@ -338,64 +332,25 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
338
  tok_emb = self.wte(input_ids)
339
  x = self.drop(tok_emb)
340
 
341
- # Collect hidden states
342
  all_hidden_states = []
343
 
344
  # Forward through transformer blocks
345
- for block_idx, block in enumerate(self.h):
346
- all_hidden_states.append(x)
347
- if up_until_layer is not None and block_idx == up_until_layer:
348
- break
349
  x = block(x)
350
 
351
- # Append final pre-ln_f state if we didn't exit early
352
- if up_until_layer is None or block_idx == len(self.h) - 1:
353
  all_hidden_states.append(x)
354
 
355
- # Normalize hidden states if requested
356
- hs_to_return = all_hidden_states
357
- if output_hidden_states and normalize_embeddings is not None:
358
- if normalize_embeddings == 'l2': # Preserve direction, get rid of magnitude
359
- hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states] # Dim -1 is the hidden state dim;
360
- # after normalization torch.norm(h_norm, p=2, dim=-1) will be 1. I.e. for every token, the hidden state dim norm is 1.
361
- elif normalize_embeddings == 'learned': # We use the learned RMSNorm (first one; used to prepare embeddings for attn)
362
- # I.e. these are the representations on which the model computes.
363
- hs_to_return = []
364
- L = len(self.h)
365
- for i, h in enumerate(all_hidden_states):
366
- if i < L:
367
- hs_to_return.append(self.h[i].norm1(h))
368
- else:
369
- hs_to_return.append(self.ln_f(h)) # Final layer norm (after the main blocks, before LM head(s))
370
-
371
- # If only hidden states requested (not logits), return early
372
- if output_hidden_states and not output_logits and labels is None:
373
- return BaseModelOutput(
374
- last_hidden_state=x,
375
- hidden_states=hs_to_return,
376
- )
377
-
378
  # Final layer norm and output head
379
  x = self.ln_f(x)
380
  logits = self.lm_head(x)
381
 
382
- # Collect all logits if requested
383
- all_logits = [logits] if output_logits else None
384
-
385
- # Compute future head logits
386
- # lm_head is the first "standard" lm head which predicts token i+1 (as all GPT models have)
387
- # self.future_heads holds all the other "MTP" future prediction heads, so self.future_heads
388
- # corresponds to the head that predicts token i+2 - aka the "second head"
389
- if self.future_heads is not None:
390
- for i, head in enumerate(self.future_heads):
391
- future_logits = head(x[:, :-(i + 1)])
392
- if output_logits:
393
- all_logits.append(future_logits)
394
-
395
  # Compute loss if labels provided
396
  loss = None
397
  if labels is not None:
398
- # compute loss from the first "standard" lm head
399
  loss = F.cross_entropy(
400
  logits.reshape(-1, self.config.vocab_size),
401
  labels.reshape(-1),
@@ -404,21 +359,21 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
404
  # Multi-token prediction loss
405
  if self.future_heads is not None:
406
  for i, head in enumerate(self.future_heads):
407
- future_logits = head(x[:, :-(i + 1)])
408
  loss = loss + F.cross_entropy(
409
  future_logits.reshape(-1, self.config.vocab_size),
410
- labels[:, (i + 1):].reshape(-1),
411
  )
412
 
413
  if not return_dict:
414
  if labels is not None:
415
- return (all_logits if output_logits else logits), loss
416
- return (all_logits if output_logits else logits), None
417
 
418
  return CausalLMOutput(
419
  loss=loss,
420
- logits=all_logits if output_logits else logits,
421
- hidden_states=hs_to_return if output_hidden_states else None,
422
  )
423
 
424
  def sample_logits(
@@ -489,7 +444,6 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
489
  torch.manual_seed(seed)
490
 
491
  all_logits = []
492
- device = seq.device
493
  b, t = seq.size()
494
 
495
  # Encode conditioning sequence into KV cache
 
8
  """
9
 
10
  import math
11
+ import os
12
+ from typing import Optional
13
+
14
+ os.environ.setdefault("USE_TORCH_XLA", "0")
15
 
16
  import torch
17
  import torch.nn as nn
18
  from torch.nn import functional as F
19
 
20
  from transformers import PreTrainedModel
21
+ from transformers.modeling_outputs import CausalLMOutput
22
+ import transformers.modeling_utils as transformers_modeling_utils
23
+ import transformers.utils.import_utils as transformers_import_utils
24
+
25
+ transformers_import_utils.is_torch_xla_available = lambda *args, **kwargs: False
26
+ transformers_modeling_utils.is_torch_xla_available = lambda *args, **kwargs: False
27
 
28
  from .configuration_auristream import AuriStreamConfig
29
 
 
75
  x2 = x[..., d:]
76
  y1 = x1 * cos + x2 * sin
77
  y2 = x1 * (-sin) + x2 * cos
78
+ return torch.cat([y1, y2], dim=3).type_as(x)
79
 
80
 
81
  class CausalSelfAttention(nn.Module):
 
233
  base_model_prefix = "model"
234
  supports_gradient_checkpointing = True
235
  _no_split_modules = ["Block"]
236
+
237
+ def __init__(self, config: AuriStreamConfig):
238
+ super().__init__(config)
239
+ if not hasattr(self, "all_tied_weights_keys"):
240
+ self.all_tied_weights_keys = {}
241
 
242
  def _init_weights(self, module):
243
  if isinstance(module, nn.Linear):
 
253
  AuriStream speech language model.
254
 
255
  A GPT-like transformer model for cochlear token prediction with optional
256
+ multi-token prediction (MTP) heads for speculative decoding.
 
257
 
258
  Developed by Greta Tuckute and Klemen Kotar.
259
  """
 
279
  else:
280
  self.future_heads = None
281
 
282
+ # Output head
283
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
284
 
285
+ self.post_init()
 
286
  # Apply special scaled init to residual projections
287
  for pn, p in self.named_parameters():
288
  if pn.endswith('c_proj.weight'):
 
302
  self,
303
  input_ids: Optional[torch.LongTensor] = None,
304
  labels: Optional[torch.LongTensor] = None,
 
305
  output_hidden_states: Optional[bool] = False,
306
  return_dict: Optional[bool] = True,
 
 
307
  # Legacy arguments for compatibility
308
  seq: Optional[torch.LongTensor] = None,
309
  tgt: Optional[torch.LongTensor] = None,
 
314
  Args:
315
  input_ids: Input token IDs of shape (batch_size, seq_len)
316
  labels: Target token IDs for computing loss
317
+ output_hidden_states: Whether to return all hidden states
318
+ return_dict: Whether to return a dict or tuple
319
+ seq: Legacy argument (alias for input_ids)
320
+ tgt: Legacy argument (alias for labels)
 
 
 
 
 
 
 
 
321
 
322
  Returns:
323
+ CausalLMOutput with logits and optional loss
 
 
 
 
 
 
324
  """
325
  # Handle legacy arguments
326
  if seq is not None:
 
332
  tok_emb = self.wte(input_ids)
333
  x = self.drop(tok_emb)
334
 
335
+ # Collect hidden states if requested
336
  all_hidden_states = []
337
 
338
  # Forward through transformer blocks
339
+ for block in self.h:
340
+ if output_hidden_states:
341
+ all_hidden_states.append(x)
 
342
  x = block(x)
343
 
344
+ if output_hidden_states:
 
345
  all_hidden_states.append(x)
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  # Final layer norm and output head
348
  x = self.ln_f(x)
349
  logits = self.lm_head(x)
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  # Compute loss if labels provided
352
  loss = None
353
  if labels is not None:
 
354
  loss = F.cross_entropy(
355
  logits.reshape(-1, self.config.vocab_size),
356
  labels.reshape(-1),
 
359
  # Multi-token prediction loss
360
  if self.future_heads is not None:
361
  for i, head in enumerate(self.future_heads):
362
+ future_logits = head(x[:, :-(i+1)])
363
  loss = loss + F.cross_entropy(
364
  future_logits.reshape(-1, self.config.vocab_size),
365
+ labels[:, (i+1):].reshape(-1),
366
  )
367
 
368
  if not return_dict:
369
  if labels is not None:
370
+ return logits, loss
371
+ return logits, None
372
 
373
  return CausalLMOutput(
374
  loss=loss,
375
+ logits=logits,
376
+ hidden_states=all_hidden_states if output_hidden_states else None,
377
  )
378
 
379
  def sample_logits(
 
444
  torch.manual_seed(seed)
445
 
446
  all_logits = []
 
447
  b, t = seq.size()
448
 
449
  # Encode conditioning sequence into KV cache