import base64
import gzip
from dataclasses import dataclass
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv1D, Dense, LayerNormalization, ZeroPadding1D


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding."""
    assert channels % 2 == 0
    # This runs once at construction time, so compute in plain NumPy and
    # convert to a tensor in a single step.
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
    scaled_time = np.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return tf.constant(np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1), dtype=tf.float32)
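# e.g. sinusoids(1500, 384) -> a (1500, 384) tensor: sine terms in the first
# 192 columns, cosine terms in the last 192, one row per frame position.

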
class LayerNorm:
    def __init__(self, n_state):  # n_state kept for interface parity
        # epsilon=1e-5 matches torch.nn.LayerNorm, which the reference
        # PyTorch implementation uses.
        self.layer_norm = LayerNormalization(axis=-1, epsilon=1e-5)

    def __call__(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        return tf.cast(self.layer_norm(tf.cast(x, tf.float32)), x.dtype)


class MultiHeadAttention:
    def __init__(self, n_state: int, n_head: int):
        self.n_head = n_head
        self.query = Dense(n_state)
        self.key = Dense(n_state, use_bias=False)
        self.value = Dense(n_state)
        self.out = Dense(n_state)

    def __call__(
        self,
        x,
        xa=None,
        mask=None,
        kv_cache=None,
    ):
        q = self.query(x)

        if xa is None:
            # Self-attention: project the current tokens, then append them to
            # any keys/values cached from earlier decoding steps.
            k = self.key(x)
            v = self.value(x)
            if kv_cache is not None:
                k = tf.concat([kv_cache[0], k], axis=1)
                v = tf.concat([kv_cache[1], v], axis=1)
        elif kv_cache is None:
            # Cross-attention, first call: project the audio features.
            k = self.key(xa)
            v = self.value(xa)
        else:
            # Cross-attention, later calls: the audio features do not change
            # between steps, so reuse the cached projections.
            k, v = kv_cache

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), (k, v), qk

    def qkv_attention(self, q, k, v, mask=None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        # Split heads: (batch, ctx, state) -> (batch, head, ctx, state/head);
        # k is additionally transposed for the matmul below.
        q = tf.transpose(tf.reshape(q, (*q.shape[:2], self.n_head, -1)), (0, 2, 1, 3)) * scale
        k = tf.transpose(tf.reshape(k, (*k.shape[:2], self.n_head, -1)), (0, 2, 3, 1)) * scale
        v = tf.transpose(tf.reshape(v, (*v.shape[:2], self.n_head, -1)), (0, 2, 1, 3))

        qk = tf.matmul(q, k)
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = tf.cast(qk, tf.float32)

        w = tf.cast(tf.nn.softmax(qk, axis=-1), q.dtype)
        out = tf.transpose(tf.matmul(w, v), (0, 2, 1, 3))
        out = tf.reshape(out, (n_batch, n_ctx, n_state))
        return out, qk
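# Note the scaling convention above: q and k are each scaled by
# (n_state / n_head) ** -0.25 before the matmul, so their product carries the
# usual 1/sqrt(d_head) attention temperature.

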
class ResidualAttentionBlock:
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp1 = Dense(n_mlp)
        self.mlp2 = Dense(n_state)
        self.mlp_ln = LayerNorm(n_state)

    def __call__(self, x, xa=None, mask=None, kv_cache=None):
        kv, cross_kv = kv_cache if kv_cache else (None, None)
        y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
        x += y
        cross_qk = None
        if self.cross_attn:
            y, cross_kv, cross_qk = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=cross_kv
            )
            x += y
        # Pre-norm MLP with GELU; cast back to the residual dtype.
        x = x + tf.cast(self.mlp2(tf.nn.gelu(self.mlp1(self.mlp_ln(x)))), x.dtype)
        return x, (kv, cross_kv), cross_qk


class AudioEncoder:
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype=tf.float16,
    ):
        # Keras Conv1D is channels-last, so inputs are (batch, frames, n_mels)
        # and the input channel count (n_mels) is inferred on first call.
        # Explicit zero padding stands in for PyTorch's Conv1d padding=1.
        self.zeropadding1d1 = ZeroPadding1D(padding=1)
        self.conv1 = Conv1D(filters=n_state, kernel_size=3)
        self.zeropadding1d2 = ZeroPadding1D(padding=1)
        self.conv2 = Conv1D(filters=n_state, kernel_size=3, strides=2)
        self._positional_embedding = tf.cast(sinusoids(n_ctx, n_state), dtype)

        self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.ln_post = LayerNorm(n_state)

    def __call__(self, x):
        x = self.zeropadding1d1(x)
        x = tf.cast(tf.nn.gelu(self.conv1(x)), x.dtype)
        x = self.zeropadding1d2(x)
        x = tf.cast(tf.nn.gelu(self.conv2(x)), x.dtype)
        assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
        x = x + tf.cast(self._positional_embedding, x.dtype)

        for block in self.blocks:
            x, _, _ = block(x)

        x = self.ln_post(x)
        return x
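# Shape flow, assuming the standard Whisper log-mel front end (30 s of audio
# -> 3000 frames): (batch, 3000, n_mels) -> conv1 -> (batch, 3000, n_state)
# -> conv2, stride 2 -> (batch, 1500, n_state) == (batch, n_ctx, n_state).

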
class TextDecoder(tf.keras.layers.Layer):
    def __init__(
        self,
        n_vocab: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        dtype=tf.float16,
    ):
        super().__init__()
        self.token_embedding = self.add_weight(
            name='token_embedding',
            shape=[n_vocab, n_state],
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[n_ctx, n_state],
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
        )

        self.blocks = [
            ResidualAttentionBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ]
        self.ln = LayerNorm(n_state)
        # Causal mask: -inf strictly above the diagonal, 0 on and below it.
        self._mask = tf.fill((n_ctx, n_ctx), float("-inf"))
        self._mask = tf.linalg.band_part(self._mask, 0, -1)
        self._mask = tf.linalg.set_diag(self._mask, tf.zeros(n_ctx))
        self._mask = tf.cast(self._mask, dtype)

    def __call__(self, x, xa, kv_cache=None):
        """
        x : shape = (batch_size, <= n_ctx)
            the text tokens
        xa : shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        # When decoding incrementally, offset the positional embedding by the
        # number of tokens already held in the per-block self-attention cache.
        offset = kv_cache[0][0][0].shape[1] if kv_cache else 0
        x = (
            tf.gather(self.token_embedding, x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )

        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)
        cross_qk = [None] * len(self.blocks)
        for e, block in enumerate(self.blocks):
            x, kv_cache[e], cross_qk[e] = block(
                x, xa, mask=self._mask, kv_cache=kv_cache[e]
            )

        x = self.ln(x)
        # The output projection is tied to the token embedding; cast it to the
        # working dtype before the matmul.
        logits = tf.matmul(x, tf.transpose(tf.cast(self.token_embedding, x.dtype)))
        return logits, kv_cache, cross_qk
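# Incremental decoding sketch (hypothetical driver code, not part of this
# module): prime the cache with the initial tokens, then feed back only the
# newly sampled token so each step embeds and attends a single position:
#
#     logits, cache, _ = decoder(tokens, audio_features)
#     next_token = tf.argmax(logits[:, -1:], axis=-1, output_type=tf.int32)
#     logits, cache, _ = decoder(next_token, audio_features, kv_cache=cache)

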
class Whisper(Model):
    def __init__(self, dims: ModelDimensions, dtype=tf.float16):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
            dtype,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
            dtype,
        )

        # By default, use the cross-attention heads of the second half of the
        # decoder layers for timestamp alignment.
        all_heads = np.zeros(
            (self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.alignment_heads = tf.transpose(tf.cast(tf.where(all_heads), dtype=tf.int32))

    def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
        if isinstance(dump, np.ndarray):
            self.alignment_heads = tf.convert_to_tensor(dump)
        elif isinstance(dump, bytes):
            array = np.frombuffer(
                gzip.decompress(base64.b85decode(dump)), dtype=bool
            ).copy()
            mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
            self.alignment_heads = tf.transpose(tf.cast(tf.where(mask), dtype=tf.int32))
        else:
            raise ValueError(
                f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or"
                " base85-encoded bytes containing alignment_head information"
            )

    def embed_audio(self, mel):
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        return self.decoder(tokens, audio_features)[0]

    def forward_with_cross_qk(self, mel, tokens):
        logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
        return logits, cross_qk

    def __call__(self, mel, tokens):
        return self.decoder(tokens, self.encoder(mel))[0]

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
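

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The dimension values are assumptions
    # mirroring OpenAI's published "tiny" multilingual checkpoint; a real
    # transcription run would load pretrained weights rather than the random
    # initialization used here.
    dims = ModelDimensions(
        n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6,
        n_audio_layer=4, n_vocab=51865, n_text_ctx=448, n_text_state=384,
        n_text_head=6, n_text_layer=4,
    )
    model = Whisper(dims, dtype=tf.float32)
    mel = tf.zeros((1, 3000, dims.n_mels))           # channels-last log-mel input
    tokens = tf.constant([[50258]], dtype=tf.int32)  # assumed <|startoftranscript|> id
    logits = model(mel, tokens)
    print(logits.shape)  # (1, 1, 51865)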