Upload AuriStream base model code
Browse files- configuration_auristream.py +0 -3
- modeling_auristream.py +35 -81
configuration_auristream.py
CHANGED
|
@@ -50,9 +50,6 @@ class AuriStreamConfig(PretrainedConfig):
|
|
| 50 |
input_conv_kernel_size: int = 0,
|
| 51 |
**kwargs,
|
| 52 |
):
|
| 53 |
-
# Force no weight tying
|
| 54 |
-
kwargs["tie_word_embeddings"] = False
|
| 55 |
-
|
| 56 |
self.vocab_size = vocab_size
|
| 57 |
self.n_embd = n_embd
|
| 58 |
self.n_layer = n_layer
|
|
|
|
| 50 |
input_conv_kernel_size: int = 0,
|
| 51 |
**kwargs,
|
| 52 |
):
|
|
|
|
|
|
|
|
|
|
| 53 |
self.vocab_size = vocab_size
|
| 54 |
self.n_embd = n_embd
|
| 55 |
self.n_layer = n_layer
|
modeling_auristream.py
CHANGED
|
@@ -8,14 +8,22 @@ https://huggingface.co/TuKoResearch/WavCochCausalV8192
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import math
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
import torch
|
| 14 |
import torch.nn as nn
|
| 15 |
from torch.nn import functional as F
|
| 16 |
|
| 17 |
from transformers import PreTrainedModel
|
| 18 |
-
from transformers.modeling_outputs import CausalLMOutput
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
from .configuration_auristream import AuriStreamConfig
|
| 21 |
|
|
@@ -67,7 +75,7 @@ def apply_rotary_emb(x, cos, sin):
|
|
| 67 |
x2 = x[..., d:]
|
| 68 |
y1 = x1 * cos + x2 * sin
|
| 69 |
y2 = x1 * (-sin) + x2 * cos
|
| 70 |
-
return torch.cat([y1, y2], dim=3)
|
| 71 |
|
| 72 |
|
| 73 |
class CausalSelfAttention(nn.Module):
|
|
@@ -225,6 +233,11 @@ class AuriStreamPreTrainedModel(PreTrainedModel):
|
|
| 225 |
base_model_prefix = "model"
|
| 226 |
supports_gradient_checkpointing = True
|
| 227 |
_no_split_modules = ["Block"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
def _init_weights(self, module):
|
| 230 |
if isinstance(module, nn.Linear):
|
|
@@ -240,8 +253,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 240 |
AuriStream speech language model.
|
| 241 |
|
| 242 |
A GPT-like transformer model for cochlear token prediction with optional
|
| 243 |
-
multi-token prediction (MTP) heads for
|
| 244 |
-
novel inference capabilities.
|
| 245 |
|
| 246 |
Developed by Greta Tuckute and Klemen Kotar.
|
| 247 |
"""
|
|
@@ -267,11 +279,10 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 267 |
else:
|
| 268 |
self.future_heads = None
|
| 269 |
|
| 270 |
-
#
|
| 271 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 272 |
|
| 273 |
-
|
| 274 |
-
self.apply(self._init_weights)
|
| 275 |
# Apply special scaled init to residual projections
|
| 276 |
for pn, p in self.named_parameters():
|
| 277 |
if pn.endswith('c_proj.weight'):
|
|
@@ -291,11 +302,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 291 |
self,
|
| 292 |
input_ids: Optional[torch.LongTensor] = None,
|
| 293 |
labels: Optional[torch.LongTensor] = None,
|
| 294 |
-
output_logits: Optional[bool] = False,
|
| 295 |
output_hidden_states: Optional[bool] = False,
|
| 296 |
return_dict: Optional[bool] = True,
|
| 297 |
-
up_until_layer: Optional[int] = None,
|
| 298 |
-
normalize_embeddings: Optional[str] = None,
|
| 299 |
# Legacy arguments for compatibility
|
| 300 |
seq: Optional[torch.LongTensor] = None,
|
| 301 |
tgt: Optional[torch.LongTensor] = None,
|
|
@@ -306,27 +314,13 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 306 |
Args:
|
| 307 |
input_ids: Input token IDs of shape (batch_size, seq_len)
|
| 308 |
labels: Target token IDs for computing loss
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
embedding state and final pre-ln_f state. Matches HuggingFace GPT-style.
|
| 314 |
-
return_dict: Whether to return a dict or tuple. If True, return a CausalLMOutput dict,
|
| 315 |
-
otherwise return a tuple.
|
| 316 |
-
up_until_layer: If set, stop the forward pass after this transformer block
|
| 317 |
-
(inclusive) and return intermediate activations. Useful for saving compute.
|
| 318 |
-
normalize_embeddings: 'l2' or 'learned' to normalize hidden states
|
| 319 |
-
seq: Legacy argument (alias for input_ids for backward compatibility)
|
| 320 |
-
tgt: Legacy argument (alias for labels for backward compatibility)
|
| 321 |
|
| 322 |
Returns:
|
| 323 |
-
|
| 324 |
-
CausalLMOutput with fields:
|
| 325 |
-
• loss (optional): Scalar training loss
|
| 326 |
-
• logits: Tensor or list of tensors of prediction logits
|
| 327 |
-
• hidden_states (optional): Tuple of hidden states
|
| 328 |
-
Otherwise:
|
| 329 |
-
Tuple of (logits or list of logits, loss).
|
| 330 |
"""
|
| 331 |
# Handle legacy arguments
|
| 332 |
if seq is not None:
|
|
@@ -338,64 +332,25 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 338 |
tok_emb = self.wte(input_ids)
|
| 339 |
x = self.drop(tok_emb)
|
| 340 |
|
| 341 |
-
# Collect hidden states
|
| 342 |
all_hidden_states = []
|
| 343 |
|
| 344 |
# Forward through transformer blocks
|
| 345 |
-
for
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
break
|
| 349 |
x = block(x)
|
| 350 |
|
| 351 |
-
|
| 352 |
-
if up_until_layer is None or block_idx == len(self.h) - 1:
|
| 353 |
all_hidden_states.append(x)
|
| 354 |
|
| 355 |
-
# Normalize hidden states if requested
|
| 356 |
-
hs_to_return = all_hidden_states
|
| 357 |
-
if output_hidden_states and normalize_embeddings is not None:
|
| 358 |
-
if normalize_embeddings == 'l2': # Preserve direction, get rid of magnitude
|
| 359 |
-
hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states] # Dim -1 is the hidden state dim;
|
| 360 |
-
# after normalization torch.norm(h_norm, p=2, dim=-1) will be 1. I.e. for every token, the hidden state dim norm is 1.
|
| 361 |
-
elif normalize_embeddings == 'learned': # We use the learned RMSNorm (first one; used to prepare embeddings for attn)
|
| 362 |
-
# I.e. these are the representations on which the model computes.
|
| 363 |
-
hs_to_return = []
|
| 364 |
-
L = len(self.h)
|
| 365 |
-
for i, h in enumerate(all_hidden_states):
|
| 366 |
-
if i < L:
|
| 367 |
-
hs_to_return.append(self.h[i].norm1(h))
|
| 368 |
-
else:
|
| 369 |
-
hs_to_return.append(self.ln_f(h)) # Final layer norm (after the main blocks, before LM head(s))
|
| 370 |
-
|
| 371 |
-
# If only hidden states requested (not logits), return early
|
| 372 |
-
if output_hidden_states and not output_logits and labels is None:
|
| 373 |
-
return BaseModelOutput(
|
| 374 |
-
last_hidden_state=x,
|
| 375 |
-
hidden_states=hs_to_return,
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
# Final layer norm and output head
|
| 379 |
x = self.ln_f(x)
|
| 380 |
logits = self.lm_head(x)
|
| 381 |
|
| 382 |
-
# Collect all logits if requested
|
| 383 |
-
all_logits = [logits] if output_logits else None
|
| 384 |
-
|
| 385 |
-
# Compute future head logits
|
| 386 |
-
# lm_head is the first "standard" lm head which predicts token i+1 (as all GPT models have)
|
| 387 |
-
# self.future_heads holds all the other "MTP" future prediction heads, so self.future_heads
|
| 388 |
-
# corresponds to the head that predicts token i+2 - aka the "second head"
|
| 389 |
-
if self.future_heads is not None:
|
| 390 |
-
for i, head in enumerate(self.future_heads):
|
| 391 |
-
future_logits = head(x[:, :-(i + 1)])
|
| 392 |
-
if output_logits:
|
| 393 |
-
all_logits.append(future_logits)
|
| 394 |
-
|
| 395 |
# Compute loss if labels provided
|
| 396 |
loss = None
|
| 397 |
if labels is not None:
|
| 398 |
-
# compute loss from the first "standard" lm head
|
| 399 |
loss = F.cross_entropy(
|
| 400 |
logits.reshape(-1, self.config.vocab_size),
|
| 401 |
labels.reshape(-1),
|
|
@@ -404,21 +359,21 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 404 |
# Multi-token prediction loss
|
| 405 |
if self.future_heads is not None:
|
| 406 |
for i, head in enumerate(self.future_heads):
|
| 407 |
-
future_logits = head(x[:, :-(i
|
| 408 |
loss = loss + F.cross_entropy(
|
| 409 |
future_logits.reshape(-1, self.config.vocab_size),
|
| 410 |
-
labels[:, (i
|
| 411 |
)
|
| 412 |
|
| 413 |
if not return_dict:
|
| 414 |
if labels is not None:
|
| 415 |
-
return
|
| 416 |
-
return
|
| 417 |
|
| 418 |
return CausalLMOutput(
|
| 419 |
loss=loss,
|
| 420 |
-
logits=
|
| 421 |
-
hidden_states=
|
| 422 |
)
|
| 423 |
|
| 424 |
def sample_logits(
|
|
@@ -489,7 +444,6 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 489 |
torch.manual_seed(seed)
|
| 490 |
|
| 491 |
all_logits = []
|
| 492 |
-
device = seq.device
|
| 493 |
b, t = seq.size()
|
| 494 |
|
| 495 |
# Encode conditioning sequence into KV cache
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import math
|
| 11 |
+
import os
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
os.environ.setdefault("USE_TORCH_XLA", "0")
|
| 15 |
|
| 16 |
import torch
|
| 17 |
import torch.nn as nn
|
| 18 |
from torch.nn import functional as F
|
| 19 |
|
| 20 |
from transformers import PreTrainedModel
|
| 21 |
+
from transformers.modeling_outputs import CausalLMOutput
|
| 22 |
+
import transformers.modeling_utils as transformers_modeling_utils
|
| 23 |
+
import transformers.utils.import_utils as transformers_import_utils
|
| 24 |
+
|
| 25 |
+
transformers_import_utils.is_torch_xla_available = lambda *args, **kwargs: False
|
| 26 |
+
transformers_modeling_utils.is_torch_xla_available = lambda *args, **kwargs: False
|
| 27 |
|
| 28 |
from .configuration_auristream import AuriStreamConfig
|
| 29 |
|
|
|
|
| 75 |
x2 = x[..., d:]
|
| 76 |
y1 = x1 * cos + x2 * sin
|
| 77 |
y2 = x1 * (-sin) + x2 * cos
|
| 78 |
+
return torch.cat([y1, y2], dim=3).type_as(x)
|
| 79 |
|
| 80 |
|
| 81 |
class CausalSelfAttention(nn.Module):
|
|
|
|
| 233 |
base_model_prefix = "model"
|
| 234 |
supports_gradient_checkpointing = True
|
| 235 |
_no_split_modules = ["Block"]
|
| 236 |
+
|
| 237 |
+
def __init__(self, config: AuriStreamConfig):
|
| 238 |
+
super().__init__(config)
|
| 239 |
+
if not hasattr(self, "all_tied_weights_keys"):
|
| 240 |
+
self.all_tied_weights_keys = {}
|
| 241 |
|
| 242 |
def _init_weights(self, module):
|
| 243 |
if isinstance(module, nn.Linear):
|
|
|
|
| 253 |
AuriStream speech language model.
|
| 254 |
|
| 255 |
A GPT-like transformer model for cochlear token prediction with optional
|
| 256 |
+
multi-token prediction (MTP) heads for speculative decoding.
|
|
|
|
| 257 |
|
| 258 |
Developed by Greta Tuckute and Klemen Kotar.
|
| 259 |
"""
|
|
|
|
| 279 |
else:
|
| 280 |
self.future_heads = None
|
| 281 |
|
| 282 |
+
# Output head
|
| 283 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 284 |
|
| 285 |
+
self.post_init()
|
|
|
|
| 286 |
# Apply special scaled init to residual projections
|
| 287 |
for pn, p in self.named_parameters():
|
| 288 |
if pn.endswith('c_proj.weight'):
|
|
|
|
| 302 |
self,
|
| 303 |
input_ids: Optional[torch.LongTensor] = None,
|
| 304 |
labels: Optional[torch.LongTensor] = None,
|
|
|
|
| 305 |
output_hidden_states: Optional[bool] = False,
|
| 306 |
return_dict: Optional[bool] = True,
|
|
|
|
|
|
|
| 307 |
# Legacy arguments for compatibility
|
| 308 |
seq: Optional[torch.LongTensor] = None,
|
| 309 |
tgt: Optional[torch.LongTensor] = None,
|
|
|
|
| 314 |
Args:
|
| 315 |
input_ids: Input token IDs of shape (batch_size, seq_len)
|
| 316 |
labels: Target token IDs for computing loss
|
| 317 |
+
output_hidden_states: Whether to return all hidden states
|
| 318 |
+
return_dict: Whether to return a dict or tuple
|
| 319 |
+
seq: Legacy argument (alias for input_ids)
|
| 320 |
+
tgt: Legacy argument (alias for labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
Returns:
|
| 323 |
+
CausalLMOutput with logits and optional loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
"""
|
| 325 |
# Handle legacy arguments
|
| 326 |
if seq is not None:
|
|
|
|
| 332 |
tok_emb = self.wte(input_ids)
|
| 333 |
x = self.drop(tok_emb)
|
| 334 |
|
| 335 |
+
# Collect hidden states if requested
|
| 336 |
all_hidden_states = []
|
| 337 |
|
| 338 |
# Forward through transformer blocks
|
| 339 |
+
for block in self.h:
|
| 340 |
+
if output_hidden_states:
|
| 341 |
+
all_hidden_states.append(x)
|
|
|
|
| 342 |
x = block(x)
|
| 343 |
|
| 344 |
+
if output_hidden_states:
|
|
|
|
| 345 |
all_hidden_states.append(x)
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
# Final layer norm and output head
|
| 348 |
x = self.ln_f(x)
|
| 349 |
logits = self.lm_head(x)
|
| 350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
# Compute loss if labels provided
|
| 352 |
loss = None
|
| 353 |
if labels is not None:
|
|
|
|
| 354 |
loss = F.cross_entropy(
|
| 355 |
logits.reshape(-1, self.config.vocab_size),
|
| 356 |
labels.reshape(-1),
|
|
|
|
| 359 |
# Multi-token prediction loss
|
| 360 |
if self.future_heads is not None:
|
| 361 |
for i, head in enumerate(self.future_heads):
|
| 362 |
+
future_logits = head(x[:, :-(i+1)])
|
| 363 |
loss = loss + F.cross_entropy(
|
| 364 |
future_logits.reshape(-1, self.config.vocab_size),
|
| 365 |
+
labels[:, (i+1):].reshape(-1),
|
| 366 |
)
|
| 367 |
|
| 368 |
if not return_dict:
|
| 369 |
if labels is not None:
|
| 370 |
+
return logits, loss
|
| 371 |
+
return logits, None
|
| 372 |
|
| 373 |
return CausalLMOutput(
|
| 374 |
loss=loss,
|
| 375 |
+
logits=logits,
|
| 376 |
+
hidden_states=all_hidden_states if output_hidden_states else None,
|
| 377 |
)
|
| 378 |
|
| 379 |
def sample_logits(
|
|
|
|
| 444 |
torch.manual_seed(seed)
|
| 445 |
|
| 446 |
all_logits = []
|
|
|
|
| 447 |
b, t = seq.size()
|
| 448 |
|
| 449 |
# Encode conditioning sequence into KV cache
|