| """Gemma adaptation for Pi, taken from big_vision. |
| |
| We follow this einsum axis naming convention: |
| B: batch |
| T: query length |
| S: k/v length |
| N: num query heads |
| K: num k/v heads |
| G: num query heads per k/v head |
| H: head dim |
| D: d_model ("features") |
| """ |
from collections.abc import Sequence
import dataclasses
from typing import Literal, TypeAlias

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp

import openpi.models.lora as lora
import openpi.shared.array_typing as at
import openpi.training.sharding as sharding

PALIGEMMA_VOCAB_SIZE = 257_152


@dataclasses.dataclass
class Config:
    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)


Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]


def get_config(variant: Variant) -> Config:
    """Returns config for specified gemma variant."""
    if variant == "dummy":
        return Config(
            width=64,
            depth=4,
            mlp_dim=128,
            num_heads=8,
            num_kv_heads=1,
            head_dim=16,
        )
    if variant == "gemma_300m":
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b_lora":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
        )
    if variant == "gemma_300m_lora":
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
        )
    raise ValueError(f"Unknown variant: {variant}")

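# Example (illustrative) usage:
#
#   config = get_config("gemma_2b")
#   assert config.width == 2048 and config.num_kv_heads == 1
#
# The "*_lora" variants differ only in `lora_configs`, which Attention and the
# FFN below consult via `config.lora_configs.get("attn")` / `.get("ffn")` to
# decide whether to wrap their einsums with low-rank adapters.

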
@at.typecheck
class RMSNorm(nn.Module):
    @nn.compact
    def __call__(self, x):
        dtype = x.dtype
        scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
        normed_inputs = normed_inputs * (1 + scale)
        return normed_inputs.astype(dtype)

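# RMSNorm above computes, in float32, x * rsqrt(mean(x**2, axis=-1) + 1e-6) and
# then scales by (1 + scale); with the zero-initialized `scale` parameter this
# is an identity scaling at initialization.

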
@at.typecheck
class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        return jnp.dot(x, self.input_embedding_table.T)

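# Note: encode and decode share the same table (weight tying). encode maps
# [b, t] token ids to [b, t, embed_dim] activations scaled by sqrt(embed_dim);
# decode maps [b, t, embed_dim] activations to [b, t, vocab_size] logits via the
# transposed table.

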
@at.typecheck
class Attention(nn.Module):
    """Attention module."""

    configs: Sequence[Config]

    @nn.compact
    def __call__(self, xs, positions, attn_mask, kv_cache):
        assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
        assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
        assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)

        dtype = next(x.dtype for x in xs if x is not None)

        qkvs = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is None:
                continue
            if config.num_kv_heads == config.num_heads:
                qkv_einsum = lora.Einsum(
                    shape=(3, config.num_heads, config.width, config.head_dim),
                    name=_name("qkv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
            else:
                q_einsum = lora.Einsum(
                    shape=(config.num_heads, config.width, config.head_dim),
                    name=_name("q_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                    lora_config=config.lora_configs.get("attn"),
                )
                q = q_einsum("BTD,NDH->BTNH", x)
                kv_einsum = lora.Einsum(
                    shape=(2, config.num_kv_heads, config.width, config.head_dim),
                    name=_name("kv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                k, v = kv_einsum("BSD,2KDH->2BSKH", x)
                qkvs.append((q, k, v))

        q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))

        q = _apply_rope(q, positions=positions)
        q *= self.configs[0].head_dim ** -0.5

        k = _apply_rope(k, positions=positions)

        assert q.dtype == k.dtype == v.dtype == dtype

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            k = jnp.concatenate([cache_k, k], axis=1)
            v = jnp.concatenate([cache_v, v], axis=1)

        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        big_neg = -2.3819763e38
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")

        out = []
        start = 0
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                end = start + x.shape[1]
                out_einsum = lora.Einsum(
                    shape=(config.num_heads, config.head_dim, config.width),
                    name=_name("attn_vec_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
                    lora_config=config.lora_configs.get("attn"),
                )
                out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
                start = end
            else:
                out.append(None)

        return out, (k, v)

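# Shape note: with the configs above (num_heads=8, num_kv_heads=1), the
# rearrange to "B T K G H" groups all G=8 query heads under a single shared k/v
# head (multi-query attention). When num_kv_heads == num_heads, the fused
# qkv_einsum branch is taken and G is 1.

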
@at.typecheck
class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype
        w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        ).astype(dtype)
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        ).astype(dtype)
        outputs = jnp.dot(activations, w_linear)
        assert outputs.dtype == dtype
        return outputs

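# FeedForward above is a GELU-gated (GEGLU-style) MLP:
#   out = (gelu(x @ w_gating[0]) * (x @ w_gating[1])) @ w_linear

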
@at.typecheck
class Block(nn.Module):
    """Transformer block."""

    configs: Sequence[Config]

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()

    @nn.compact
    def __call__(self, xs, kv_cache, positions, attn_mask, deterministic=True):
        xs = sharding.activation_sharding_constraint(xs)
        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x

        attn = Attention(configs=self.configs, name="attn")

        pre_attn = []
        for i, x in enumerate(xs):
            if x is not None:
                x = RMSNorm(name=_name("pre_attention_norm", i))(x)
            pre_attn.append(x)

        pre_attn = sharding.activation_sharding_constraint(pre_attn)
        post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
        post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
        post_attn = sharding.activation_sharding_constraint(post_attn)
        xs = jax.tree.map(lambda x, y: x + y, xs, post_attn)
        xs = sharding.activation_sharding_constraint(xs)

        out = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                x = RMSNorm(name=_name("pre_ffw_norm", i))(x)
                x = lora.FeedForward(
                    features=config.width,
                    hidden_dim=config.mlp_dim,
                    name=_name("mlp", i),
                    lora_config=config.lora_configs.get("ffn"),
                )(x)
            out.append(x)

        out = sharding.activation_sharding_constraint(out)

        out = jax.tree.map(lambda x: drop(x, deterministic), out)
        xs = jax.tree.map(lambda x, y: x + y, xs, out)
        xs = sharding.activation_sharding_constraint(xs)

        return xs, kv_cache


KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]


@at.typecheck
class Module(nn.Module):
    """Transformer model, supporting a mixture of different weights for different tokens."""

    configs: Sequence[Config]
    embed_dtype: str

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()

    def setup(self):
        assert all(config.depth == self.configs[0].depth for config in self.configs)

        self.embedder = Embedder(
            vocab_size=PALIGEMMA_VOCAB_SIZE,
            embed_dim=self.configs[0].width,
            name="embedder",
        )
        block_cls = nn.remat(
            Block,
            prevent_cse=False,
            static_argnums=(5,),
            policy=jax.checkpoint_policies.nothing_saveable,
        )
        self.layers = nn.scan(
            block_cls,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast),
            length=self.configs[0].depth,
        )(
            configs=self.configs,
            dropout=self.dropout,
            dropout_bdims=self.dropout_bdims,
        )
        self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]

    @at.typecheck
    def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
        return self.embedder.encode(tokens).astype(self.embed_dtype)

    @at.typecheck
    def __call__(
        self,
        embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
        positions: at.Int[at.Array, "b t"],
        mask: at.Bool[at.Array, "b t s"],
        *,
        kv_cache: KVCache | None = None,
        deterministic: bool = True,
    ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
        embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
        mask = jnp.asarray(mask)[:, None, :, :]

        embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, deterministic)

        assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)

        return [f(e) if e is not None else e for f, e in zip(self.final_norms, embedded, strict=True)], kv_cache

    def init(self):
        """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
        self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
        self(
            [jnp.zeros((1, 1, c.width)) for c in self.configs],
            jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
            jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
        )

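# Note: nn.scan with variable_axes={"params": 0} stacks each Block's parameters
# along a leading layer axis and stacks the per-layer (k, v) outputs the same
# way, which is why KVCache carries a leading "l" (layer) dimension.

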
def _apply_rope(x, *, positions, max_wavelength=10_000):
    """Applies RoPE positions [B, L] to x [B, L, H, D]."""
    freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None] / timescale[None, None, :]
    radians = radians[..., None, :]
    assert radians.dtype == jnp.float32
    sin, cos = jnp.sin(radians), jnp.cos(radians)
    x1, x2 = jnp.split(x, 2, axis=-1)
    res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    assert res.dtype == jnp.float32
    return res.astype(x.dtype)

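# RoPE detail: feature i in the first half is paired with feature i + D/2 in the
# second half, and each pair is rotated by the angle
#   positions / max_wavelength ** (2 * i / D),
# computed in float32 before casting back to the input dtype.

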
def _name(name, i):
    if i == 0:
        return name
    return f"{name}_{i}"