| import tensorflow as tf |
| from tensorflow.keras.layers import Dense,LayerNormalization,Embedding |
| from tensorflow.keras import Model |
| import math |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class ModelArgs:
    """Hyperparameters for the model below (values appear to match Phi-2 — verify)."""

    # Maximum number of positions; not read by the layers in this file,
    # presumably kept for config compatibility — TODO confirm.
    n_positions: int = 2048
    # Size of the token vocabulary (embedding rows and lm_head outputs).
    vocab_size: int = 51200
    # Hidden/embedding width of the transformer.
    n_embd: int = 2560
    # Number of attention heads; n_embd must be divisible by this.
    n_head: int = 32
    # Number of ParallelBlock layers in the stack.
    n_layer: int = 32
    # Number of leading head features rotated by RoPE (partial rotary).
    rotary_dim: int = 32
|
|
|
|
class RoPEAttention:
    """Multi-head self-attention with (partial) rotary position embeddings.

    Queries/keys/values come from separate Dense projections; the first
    `rotary_dim` features of each head are rotated by RoPE. Supports an
    optional (key, value) cache for incremental decoding and returns the
    updated cache alongside the output.
    """

    def __init__(self, dims: int, n_head: int, rotary_dim: int):
        self.n_head = n_head

        # One projection each for Q, K, V, plus the output projection.
        self.q_proj = Dense(dims)
        self.k_proj = Dense(dims)
        self.v_proj = Dense(dims)
        self.dense = Dense(dims)

        self.rope = RoPE(rotary_dim, traditional=False)

    def __call__(self, x, mask=None, cache=None):
        """Attend over `x` (batch, seq, dims); `mask` is additive logits bias."""
        B, L, _ = x.shape
        heads = self.n_head

        def split_heads(t):
            # (B, L, dims) -> (B, heads, L, head_dim)
            return tf.transpose(tf.reshape(t, (B, L, heads, -1)), (0, 2, 1, 3))

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        if cache is None:
            q = self.rope(q)
            k = self.rope(k)
        else:
            # Rotate only the new positions, offset past the cached length,
            # then append the new keys/values to the cache.
            k_past, v_past = cache
            past_len = k_past.shape[2]
            q = self.rope(q, offset=past_len)
            k = self.rope(k, offset=past_len)
            k = tf.concat([k_past, k], axis=2)
            v = tf.concat([v_past, v], axis=2)

        # Compute attention logits in float32 for numerical stability.
        q = tf.cast(q, tf.float32)
        k = tf.cast(k, tf.float32)

        scale = math.sqrt(1 / q.shape[-1])
        logits = tf.matmul(q * scale, k, transpose_b=True)
        if mask is not None:
            logits = logits + mask

        weights = tf.cast(tf.nn.softmax(logits, axis=-1), v.dtype)
        out = tf.matmul(weights, v)
        # (B, heads, L, head_dim) -> (B, L, dims)
        out = tf.reshape(tf.transpose(out, (0, 2, 1, 3)), (B, L, -1))

        return self.dense(out), (k, v)
|
|
|
|
class MLP:
    """Two-layer feed-forward block: Dense -> GELU (tanh approx) -> Dense."""

    def __init__(self, dim, hidden_dim):
        self.fc1 = Dense(hidden_dim)  # expand to hidden_dim
        self.fc2 = Dense(dim)         # project back to dim

    def __call__(self, x):
        # BUG FIX: tf.nn.gelu's `approximate` argument is a boolean, not a
        # string. Passing "precise" (presumably carried over from an MLX-style
        # GELU(approx="precise")) only selected the tanh approximation because
        # any non-empty string is truthy. Make the choice explicit with True
        # so the behavior is identical but survives stricter argument
        # validation in future TF releases.
        return self.fc2(tf.nn.gelu(self.fc1(x), approximate=True))
|
|
|
|
class ParallelBlock:
    """Transformer block where attention and MLP run in parallel.

    Both branches read the same LayerNorm output and their results are
    summed with the residual input (Phi-style parallel formulation).
    """

    def __init__(self, config: ModelArgs):
        width = config.n_embd
        self.self_attn = RoPEAttention(width, config.n_head, config.rotary_dim)
        self.input_layernorm = LayerNormalization()
        self.mlp = MLP(width, width * 4)  # conventional 4x expansion

    def __call__(self, x, mask, cache):
        normed = self.input_layernorm(x)
        attn_out, cache = self.self_attn(normed, mask, cache)
        mlp_out = self.mlp(normed)
        # Same summation order as before to keep float results bit-identical.
        return attn_out + mlp_out + x, cache
|
|
|
|
class Transformer:
    """Token embedding, a stack of ParallelBlocks, and a final LayerNorm.

    `cache` is a per-layer list of (key, value) tuples (or None on the
    first call); it is updated in place and returned with the hidden state.
    """

    def __init__(self, config: ModelArgs):
        self.embed_tokens = Embedding(config.vocab_size, config.n_embd)
        self.layers = [ParallelBlock(config) for _ in range(config.n_layer)]
        self.final_layernorm = LayerNormalization()

    def __call__(self, x, mask, cache):
        h = self.embed_tokens(x)
        if cache is None:
            # One cache slot per layer, filled as each block runs.
            cache = [None] * len(self.layers)

        for idx, block in enumerate(self.layers):
            h, cache[idx] = block(h, mask, cache[idx])
        return self.final_layernorm(h), cache
|
|
|
|
class Phi2(Model):
    """Phi-2 language model: Transformer trunk plus a linear LM head."""

    def __init__(self, config: ModelArgs):
        super(Phi2, self).__init__()
        self.model = Transformer(config)
        self.lm_head = Dense(config.vocab_size)

    def __call__(
        self,
        x,
        mask=None,
        cache=None,
    ):
        """Run the model over integer token ids `x` of shape (batch, seq).

        Returns (logits, cache). A caller-supplied additive `mask` is now
        honored (previously the parameter was dead: it was unconditionally
        overwritten with None); when no mask is given and seq > 1, a causal
        mask is built.
        """
        if mask is None and x.shape[1] > 1:
            seq_len = x.shape[1]
            # Additive causal mask: -inf strictly above the diagonal,
            # 0 on and below it.
            mask = tf.fill((seq_len, seq_len), float("-inf"))
            mask = tf.linalg.band_part(mask, 0, -1)
            mask = tf.linalg.set_diag(mask, tf.zeros(seq_len))
            # BUG FIX: the mask must stay floating point. `x` holds integer
            # token ids, so `tf.cast(mask, x.dtype)` destroyed the -inf
            # entries and made `scores + mask` fail on an int/float dtype
            # mismatch inside the attention. Logits are computed in float32,
            # so cast the mask to float32.
            mask = tf.cast(mask, tf.float32)

        y, cache = self.model(x, mask, cache)
        return self.lm_head(y), cache
|
|
|
|
class RoPE:
    """Rotary position embedding applied to the first `dims` features.

    In the default (non-traditional) form the rotated pair is the first and
    second half of the `dims` slice; any features past `dims` pass through
    unchanged (partial rotary). The traditional form rotates interleaved
    even/odd pairs and does not support partial application.
    """

    def __init__(self, dims: int, traditional: bool = False, base=None):
        self.dims = dims
        self.traditional = traditional
        # None means "use the default base" (resolved in create_cos_sin_theta).
        self.base = base

    def _compute_rope(self, costheta, sintheta, x):
        """Rotate (first-half, second-half) feature pairs of the rotary slice."""
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            # Partial rotary: features beyond `dims` are passed through as-is.
            rx = tf.concat([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = tf.concat([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        """Rotate interleaved (even, odd) feature pairs; no partial support."""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        # Re-interleave the rotated pairs along a trailing axis; the caller's
        # final reshape restores the original feature layout.
        rx = tf.concat([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        """Apply RoPE to `x` (..., seq, features); `offset` shifts positions
        (used for cached incremental decoding)."""
        shape = x.shape
        # Fold all leading dims into one batch axis for broadcasting.
        x = tf.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return tf.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base=None,
        dtype=tf.float32,
    ):
        """Return cos/sin tables of shape (N - offset, D // 2).

        BUG FIX: `__call__` always forwards `base=self.base`, which defaults
        to None, so the previous `base: float = 10000` default was never used
        and `tf.math.log(None)` crashed. Resolve None here, and cast the base
        to `dtype` before `tf.math.log` (log is undefined for integer tensors,
        so a plain Python int base would also have failed).
        """
        if base is None:
            base = 10000.0
        half = D // 2
        positions = tf.range(offset, N, dtype=dtype)
        freqs = tf.math.exp(
            -tf.range(0, half, dtype=dtype)
            * (tf.math.log(tf.cast(base, dtype)) / half)
        )
        theta = tf.reshape(positions, (-1, 1)) * tf.reshape(freqs, (1, -1))
        costheta = tf.math.cos(theta)
        sintheta = tf.math.sin(theta)

        return costheta, sintheta