from dataclasses import dataclass
from typing import Literal

import torch
from torch.nn import functional as F


def gelu_approx(x: torch.Tensor) -> torch.Tensor:
    # GELU with the tanh approximation rather than the exact erf form.
    return F.gelu(x, approximate="tanh")
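
# For reference, the tanh approximation evaluates
#   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
# in place of the exact erf-based GELU.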


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)
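
# Usage sketch (shapes are illustrative, not fixed by this module): the
# parameters of a torch.nn.Linear can be wrapped directly, e.g.
#   fc = torch.nn.Linear(16, 32)
#   w = LinearWeights(weight=fc.weight, bias=fc.bias)
#   x = torch.randn(2, 16)
#   assert torch.equal(linear(x, w), fc(x))  # F.linear computes x @ W.T + b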


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    # normalized_shape is taken from the bias, so normalization runs over
    # the trailing dimension(s) the bias spans.
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
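
# Example: with activations of shape (bsz, seq, d_model) and 1-D weight and
# bias of shape (d_model,), this normalizes over the last dimension only,
# equivalent to F.layer_norm(x, (d_model,), w.weight, w.bias).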


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    # "gelu_approx" is the only activation MLPWeights.act admits, so it is
    # applied unconditionally.
    x = linear(x, w.fc1)
    x = gelu_approx(x)
    x = linear(x, w.fc2)
    return x
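
# Usage sketch (sizes are assumptions for illustration; the hidden dimension
# is conventionally about 4x d_model in transformer MLPs):
#   d, h = 64, 256
#   w = MLPWeights(
#       fc1=LinearWeights(torch.randn(h, d) * 0.02, torch.zeros(h)),
#       fc2=LinearWeights(torch.randn(d, h) * 0.02, torch.zeros(d)),
#   )
#   y = mlp(torch.randn(2, 8, d), w)  # -> shape (2, 8, d)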


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights


def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Fused QKV projection, then split and reshape each piece to
    # (bsz, n_heads, q_len, head_dim) for scaled_dot_product_attention.
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge heads back into d_model and apply the output projection.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
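

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: sizes are
    # arbitrary and the random weights are stand-ins for real checkpoint
    # parameters.
    torch.manual_seed(0)
    bsz, seq, d_model, n_heads = 2, 8, 64, 4

    def rand_linear(out_f: int, in_f: int) -> LinearWeights:
        # Hypothetical helper for this test only.
        return LinearWeights(torch.randn(out_f, in_f) * 0.02, torch.zeros(out_f))

    x = torch.randn(bsz, seq, d_model)
    x = layer_norm(x, LayerNormWeights(torch.ones(d_model), torch.zeros(d_model)))
    x = x + attn(
        x,
        AttentionWeights(qkv=rand_linear(3 * d_model, d_model),
                         proj=rand_linear(d_model, d_model)),
        n_heads,
    )
    x = x + mlp(
        x,
        MLPWeights(fc1=rand_linear(4 * d_model, d_model),
                   fc2=rand_linear(d_model, 4 * d_model)),
    )
    assert x.shape == (bsz, seq, d_model)
    print("ok:", tuple(x.shape))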