| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Inference-only Gemma model implementation.""" |
|
|
| import tensorflow as tf |
| from tensorflow.keras.layers import Dense |
| from tensorflow.keras import Model |
| import dataclasses |
|
|
|
|
@dataclasses.dataclass()
class GemmaConfig:
    """Hyper-parameters describing a Gemma model.

    Every field has a default; override any subset by keyword (or
    positionally, in declaration order) when constructing the config.
    """

    # Number of rows in the token-embedding table (used by ``Embedder``).
    vocab_size: int = dataclasses.field(default=256000)
    # Not read anywhere in this module; presumably bounds the rotary table /
    # KV-cache length chosen by the caller — TODO confirm.
    max_position_embeddings: int = dataclasses.field(default=8192)
    # How many ``GemmaDecoderLayer`` blocks ``Gemma`` stacks.
    num_hidden_layers: int = dataclasses.field(default=28)
    # Query heads per attention layer.
    num_attention_heads: int = dataclasses.field(default=16)
    # Key/value heads per attention layer; ``GemmaAttention`` asserts that it
    # divides ``num_attention_heads``.
    num_key_value_heads: int = dataclasses.field(default=16)
    # Width of the embeddings and of the residual stream.
    hidden_size: int = dataclasses.field(default=3072)
    # Inner width of the gated MLP.
    intermediate_size: int = dataclasses.field(default=24576)
    # Per-head size of queries, keys and values.
    head_dim: int = dataclasses.field(default=256)
    # Epsilon added to the mean square inside every ``RMSNorm``.
    rms_norm_eps: float = dataclasses.field(default=1e-6)
|
|
|
|
def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0):
    """Precomputes the rotary-embedding table as unit-magnitude complex numbers.

    Entry ``[t, i]`` is ``cos(a) + 1j * sin(a)`` with
    ``a = t * theta ** (-2 * i / dim)``, so multiplying a complex feature pair
    by it rotates the pair by ``a`` without changing its magnitude.

    Args:
      dim: Size of the rotated feature dimension; ``dim // 2`` frequencies
        are produced.
      end: Number of positions (rows) to precompute.
      theta: Base of the geometric frequency progression.

    Returns:
      A ``complex64`` tensor of shape ``(end, dim // 2)``.
    """
    exponents = tf.cast(tf.range(0, dim, 2)[:(dim // 2)], tf.float32) / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = tf.cast(tf.range(end), tf.float32)
    # Outer product in float32: angles[t, i] = t * inv_freq[i].
    angles = positions[:, None] * inv_freq[None, :]
    # Polar form e^{i*angle}. tf.complex(1, angle) would give 1 + i*angle,
    # whose magnitude grows with position and is not a rotation.
    return tf.complex(tf.cos(angles), tf.sin(angles))
|
|
|
|
def apply_rotary_emb(x, freqs_cis):
    """Applies the rotary embedding to the query and key tensors.

    The last axis of ``x`` is split into two halves that serve as the real
    and imaginary parts of ``head_dim // 2`` complex numbers. Each one is
    multiplied by the matching entry of ``freqs_cis``, and the result is
    written back as ``[real parts | imaginary parts]``.

    Args:
      x: Rank-4 tensor laid out as ``(batch, seq, heads, head_dim)`` (the
        layout ``GemmaAttention`` reshapes to); ``head_dim`` must be even.
      freqs_cis: Complex tensor that broadcasts against
        ``(batch, heads, seq, head_dim // 2)``, e.g. the
        ``(seq, head_dim // 2)`` rows of ``precompute_freqs_cis`` for the
        current positions.

    Returns:
      A tensor with the same shape and dtype as ``x``.
    """
    # (batch, seq, heads, dim) -> (batch, heads, seq, dim) so freqs_cis
    # broadcasts over the trailing (seq, dim // 2) axes. Rotate in float32.
    x_t = tf.cast(tf.transpose(x, [0, 2, 1, 3]), tf.float32)
    real, imag = tf.split(x_t, num_or_size_splits=2, axis=-1)
    rotated = tf.complex(real, imag) * tf.cast(freqs_cis, tf.complex64)
    # Inverse of the split above: first half real parts, second half imaginary.
    x_out = tf.concat([tf.math.real(rotated), tf.math.imag(rotated)], axis=-1)
    return tf.transpose(tf.cast(x_out, x.dtype), [0, 2, 1, 3])
|
|
|
|
class Embedder(tf.keras.layers.Layer):
    """Token-embedding table used both for input lookup and for output logits."""

    def __init__(self, config: GemmaConfig, **kwargs):
        """Creates the ``(vocab_size, hidden_size)`` embedding table.

        Args:
          config: Model config; ``vocab_size`` and ``hidden_size`` set the
            table shape.
          **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        # Keras requires the base initializer to run before add_weight or any
        # attribute tracking.
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.embed_dim = config.hidden_size
        self.input_embedding_table = self.add_weight(
            name='input_embedding_table',
            shape=(self.vocab_size, self.embed_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
        )

    def encode(self, x):
        """Looks up integer token ids ``x`` and scales by ``sqrt(embed_dim)``.

        Returns a tensor of shape ``x.shape + (embed_dim,)``.
        """
        x = tf.gather(self.input_embedding_table, x)
        # tf.math.sqrt has no integer kernel, so cast the width to x's dtype first.
        return x * tf.math.sqrt(tf.cast(self.embed_dim, x.dtype))

    def decode(self, x):
        """Projects hidden states onto the vocabulary using the same table.

        Maps ``(..., embed_dim)`` to ``(..., vocab_size)`` for any number of
        leading dimensions.
        """
        return tf.einsum('...d,vd->...v', x, self.input_embedding_table)
|
|
|
|
class RMSNorm(tf.keras.layers.Layer):
    """Root-mean-square normalization over the last axis with a learned scale."""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
        **kwargs,
    ):
        """Creates the per-feature scale.

        Args:
          dim: Size of the normalized (last) axis; shape of ``weight``.
          eps: Added to the mean square before the reciprocal square root.
          add_unit_offset: If True, the effective scale is ``1 + weight``
            (so the zero-initialized weight starts as identity); otherwise
            it is ``weight``.
          **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        # Must run first so add_weight works and the weight is tracked.
        super().__init__(**kwargs)
        self.dim = dim
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = self.add_weight(
            name='weight',
            shape=(self.dim,),
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
        )

    def _norm(self, x):
        """Returns ``x / sqrt(mean(x**2, last axis) + eps)``."""
        mean_square = tf.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(mean_square + self.eps)

    def __call__(self, x):
        """Normalizes and scales ``x``; the result has ``x``'s dtype.

        Both the normalization and the scaling run in float32, so
        lower-precision inputs never meet the float32 weight directly (TF
        does not promote dtypes implicitly).
        """
        normed = self._norm(tf.cast(x, tf.float32))
        weight = tf.cast(self.weight, tf.float32)
        scale = 1.0 + weight if self.add_unit_offset else weight
        return tf.cast(normed * scale, x.dtype)
|
|
|
|
class GemmaMLP(tf.keras.layers.Layer):
    """Gated feed-forward block: ``down(gelu(gate(x)) * up(x))``."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        **kwargs,
    ):
        """Creates the three projections.

        Args:
          hidden_size: Output width (the residual-stream width).
          intermediate_size: Width of the gated inner representation.
          **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        # Subclassing Layer (initializer first) lets an enclosing Keras model
        # track the Dense weights created below.
        super().__init__(**kwargs)
        # The reference Gemma projections carry no bias terms.
        self.gate_proj = Dense(intermediate_size, use_bias=False)
        self.up_proj = Dense(intermediate_size, use_bias=False)
        self.down_proj = Dense(hidden_size, use_bias=False)

    def __call__(self, x):
        """Applies the gated MLP along the last axis of ``x``."""
        # The reference Gemma MLP uses the tanh approximation of GELU.
        gate = tf.nn.gelu(self.gate_proj(x), approximate=True)
        return self.down_proj(gate * self.up_proj(x))
|
|
|
|
class GemmaAttention(tf.keras.layers.Layer):
    """Self-attention with rotary embeddings, a write-through KV cache and
    grouped-query attention (several query heads share one KV head)."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        **kwargs,
    ):
        """Creates the fused QKV projection and the output projection.

        Args:
          hidden_size: Input/output width.
          num_heads: Number of query heads.
          num_kv_heads: Number of key/value heads; must divide ``num_heads``.
          head_dim: Per-head size of queries, keys and values.
          **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        # Subclassing Layer (initializer first) lets an enclosing Keras model
        # track the Dense weights created below.
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        assert self.num_heads % self.num_kv_heads == 0, (
            'num_heads must be a multiple of num_kv_heads')
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # 1 / sqrt(head_dim) attention temperature.
        self.scaling = self.head_dim**-0.5

        # One fused projection yields [queries | keys | values]. The reference
        # Gemma projections have no bias.
        self.qkv_proj = Dense(self.q_size + 2 * self.kv_size, use_bias=False)
        self.o_proj = Dense(self.hidden_size, use_bias=False)

    def __call__(
        self,
        hidden_states,
        freqs_cis,
        kv_write_indices,
        kv_cache,
        mask,
    ):
        """Attends from the new tokens to every slot of the KV cache.

        Args:
          hidden_states: Rank-3 tensor ``(batch, input_len, hidden_size)``.
          freqs_cis: Rotary table rows for the new positions, passed to
            ``apply_rotary_emb`` for both queries and keys.
          kv_write_indices: Indices for ``tf.tensor_scatter_nd_update`` into
            each cache. The update is ``(batch, input_len, num_kv_heads,
            head_dim)``, so with a ``(batch, cache_len, num_kv_heads,
            head_dim)`` cache this presumably is ``(batch, input_len, 2)``
            ``[batch, position]`` pairs — confirm with the caller.
          kv_cache: ``(k_cache, v_cache)`` variables; both are updated in
            place via ``.assign`` and then read back in full.
          mask: Additive mask that must broadcast to
            ``(batch, num_heads, input_len, cache_len)``.

        Returns:
          Tensor ``(batch, input_len, hidden_size)``.
        """
        assert len(hidden_states.shape) == 3
        # Dynamic sizes so unknown (None) batch/sequence dimensions also work.
        dynamic_shape = tf.shape(hidden_states)
        batch_size, input_len = dynamic_shape[0], dynamic_shape[1]

        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = tf.split(qkv, [self.q_size, self.kv_size, self.kv_size],
                              axis=-1)
        xq = tf.reshape(xq, (batch_size, -1, self.num_heads, self.head_dim))
        xk = tf.reshape(xk, (batch_size, -1, self.num_kv_heads, self.head_dim))
        xv = tf.reshape(xv, (batch_size, -1, self.num_kv_heads, self.head_dim))

        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

        # Write the new keys/values into the caches, then attend over the
        # whole (updated) caches.
        k_cache, v_cache = kv_cache
        k_cache.assign(tf.tensor_scatter_nd_update(k_cache, kv_write_indices, xk))
        v_cache.assign(tf.tensor_scatter_nd_update(v_cache, kv_write_indices, xv))

        key, value = k_cache, v_cache
        if self.num_kv_heads != self.num_heads:
            # Repeat each KV head consecutively so query head h reads KV head
            # h // num_queries_per_kv.
            key = tf.repeat(key, self.num_queries_per_kv, axis=2)
            value = tf.repeat(value, self.num_queries_per_kv, axis=2)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = tf.transpose(xq, (0, 2, 1, 3))
        k = tf.transpose(key, (0, 2, 1, 3))
        v = tf.transpose(value, (0, 2, 1, 3))

        # scores: (batch, heads, input_len, cache_len)
        scores = tf.matmul(q, k, transpose_b=True) * self.scaling
        scores = scores + mask
        # Softmax in float32 for stability, then back to the activation dtype.
        probs = tf.cast(tf.nn.softmax(tf.cast(scores, tf.float32), axis=-1),
                        q.dtype)

        # (batch, heads, input_len, head_dim) -> (batch, input_len, heads*head_dim)
        output = tf.matmul(probs, v)
        output = tf.reshape(tf.transpose(output, (0, 2, 1, 3)),
                            (batch_size, input_len, -1))
        return self.o_proj(output)
|
|
|
|
class GemmaDecoderLayer(tf.keras.layers.Layer):
    """One transformer block.

    The block applies pre-norm self-attention and then a pre-norm gated MLP,
    each wrapped in a residual connection.
    """

    def __init__(
        self,
        config: GemmaConfig,
        **kwargs,
    ):
        """Builds the attention, MLP and the two RMSNorms from ``config``.

        Args:
          config: Model hyper-parameters.
          **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        # Subclassing Layer (initializer first) lets Keras attribute tracking
        # see any Keras sub-layers/weights assigned below.
        super().__init__(**kwargs)
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def __call__(
        self,
        hidden_states,
        freqs_cis,
        kv_write_indices,
        kv_cache,
        mask,
    ):
        """Runs the block; all extra arguments are passed to ``self_attn``.

        Returns a tensor with the same shape as ``hidden_states``.
        """
        # Attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states
|
|
|
|
class Gemma(Model):
    """Gemma decoder-only transformer that maps token ids to vocabulary logits."""

    def __init__(self, config: GemmaConfig, **kwargs):
        """Builds the embedder, the decoder stack and the final norm.

        Args:
          config: Model hyper-parameters.
          **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        self.config = config
        self.vocab_size = config.vocab_size

        self.embedder = Embedder(config)
        # Stored as ``decoder_layers``: Keras exposes ``Model.layers`` as a
        # read-only property, so assigning ``self.layers`` raises.
        self.decoder_layers = [
            GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # No separate output projection: logits come from
        # ``self.embedder.decode``, i.e. they reuse the embedding table.
        # (``output`` is also a read-only Keras property.)

    def __call__(
        self,
        data,
        freqs_cis,
        kv_write_indices,
        kv_caches,
        mask
    ):
        """Runs one forward pass and returns logits over the vocabulary.

        Args:
          data: Integer token ids. The attention layers require the embedded
            result to be rank 3, so presumably ``(batch, seq)``.
          freqs_cis: Rotary table rows for the current positions; shared by
            every decoder layer.
          kv_write_indices: KV-cache write indices; shared by every layer.
          kv_caches: Indexable with one ``(k_cache, v_cache)`` pair of
            variables per decoder layer, in layer order.
          mask: Additive attention mask; shared by every layer.

        Returns:
          Logits of shape ``data.shape + (vocab_size,)``.
        """
        hidden_states = self.embedder.encode(data)
        for i, layer in enumerate(self.decoder_layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
        hidden_states = self.norm(hidden_states)
        return self.embedder.decode(hidden_states)