"""
# RetNPhi: Byte-Level Hybrid of Phi-3.5 and RetNet

RetNPhi is an experimental architecture that transforms Phi-3.5 into a byte-level language model by incorporating RetNet-inspired mechanisms. Operating directly on raw byte sequences lets a single model handle any file type.

## Key Features:

1. **Byte-Level Processing**: Operates directly on raw byte sequences, enabling universal application to any file type.
2. **RetNet Integration**: Incorporates RetNet's multi-scale exponential decay and group normalization for efficient long-range dependency modeling.
3. **Dual-Mode Processing**: Supports a parallel mode for efficient training and a recurrent mode for inference.
4. **Selective Fine-Tuning**: Trains only specific layers (e.g., token embedding, post-attention layer normalizations) while keeping most of the original Phi-3.5 weights frozen.
5. **Weight-Decomposed Low-Rank Adaptation (DoRA)**: Applies DoRA to self-attention output projections for efficient adaptation while preserving pretrained knowledge.

## Implementation Strategy:

- **Weight Reuse**: Utilizes frozen weights from the original Phi-3.5 model for most layers.
- **Flexible DoRA Application**: Allows configuration of which layers and targets to apply DoRA to.
- **Configurable Architecture**: Supports both retention-based and original attention mechanisms.
- **Untied Embeddings Option**: Provides the ability to use separate input and output embeddings.

## Training and Inference:

- Implements efficient training loops with customizable learning rate schedules.
- Supports both training from scratch and fine-tuning from a checkpoint.
- Provides a generation function for text completion tasks.

## Goals:

- Explore the potential of retention-like mechanisms in a byte-level Phi architecture.
- Leverage dual-mode processing for efficient training and inference.
- Develop a universal model capable of processing any file type.

Note: This is a highly experimental implementation, designed for research and exploration rather than production use. It demonstrates the potential of combining pretrained models with novel architectures and efficient fine-tuning techniques.

Author: Josef Albers
Date: Aug 28, 2024
"""

import glob
import json
import math
import time
from datetime import datetime
from types import SimpleNamespace

import fire
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_unflatten

from datasets import load_dataset

class Tokenizer:
    def __init__(self, file_path=None):
        if file_path is None:
            self.vocab = list(range(256))
        else:
            with open(file_path, 'r') as f:
                content = f.read().lower().encode('utf-8')
            self.vocab = sorted(set(content))
        self.vocab_size = len(self.vocab)
        self.byte_to_index = {byte: index for index, byte in enumerate(self.vocab)}
        self.index_to_byte = {index: byte for index, byte in enumerate(self.vocab)}

    def encode(self, text):
        byte_seq = text.encode('utf-8')
        return [self.byte_to_index[byte] for byte in byte_seq]

    def decode(self, indices):
        byte_seq = bytes(self.index_to_byte[index] for index in indices)
        return byte_seq.decode('utf-8', errors='ignore')

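Because the default vocabulary is simply the 256 possible byte values, tokenization reduces to a UTF-8 round trip. A minimal sketch of that idea, independent of the `Tokenizer` class itself (the sample string is illustrative):

```python
# Byte-level tokenization: with the identity vocab (0..255), token ids are raw bytes.
text = "5 cartons"
ids = list(text.encode('utf-8'))      # one token id per byte
assert all(0 <= i < 256 for i in ids)
decoded = bytes(ids).decode('utf-8', errors='ignore')
assert decoded == text
```

This is why the model needs no tokenizer training and can, in principle, consume any file type.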
class SuRoPE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.hidden_size // config.num_attention_heads
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))
        self._long_factor = mx.array(config.rope_scaling["long_factor"], dtype=mx.float32)
        self._short_factor = mx.array(config.rope_scaling["short_factor"], dtype=mx.float32)

    def __call__(self, q, k, position_ids):
        cos, sin = self._get_cos_sin(position_ids)
        q = (q * cos) + (self._rotate_half(q) * sin)
        k = (k * cos) + (self._rotate_half(k) * sin)
        return q, k

    def _get_cos_sin(self, position_ids):
        su_factor = self._short_factor
        position_ids_expanded = position_ids[:, None, :]
        inv_freq = 1.0 / (su_factor * self.rope_theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.expand_dims(mx.cos(emb) * self.scaling_factor, axis=1)
        sin = mx.expand_dims(mx.sin(emb) * self.scaling_factor, axis=1)
        return cos, sin

    def _rotate_half(self, x):
        midpoint = x.shape[-1] // 2
        x1, x2 = x[..., :midpoint], x[..., midpoint:]
        return mx.concatenate([-x2, x1], axis=-1)

class Phi3Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.head_dim = head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5
        chop_1 = self.n_heads * self.head_dim
        chop_2 = chop_1 + self.n_kv_heads * self.head_dim
        self.chop = [chop_1, chop_2]
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = SuRoPE(config)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        B, L, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        if cache is None:
            position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
            q, k = self.rope(q, k, position_ids)
            mask = mx.triu(mx.full((v.shape[2], v.shape[2]), -mx.inf), k=1)
            if attention_mask is not None:
                mask += mx.where(attention_mask[:, :, None] * attention_mask[:, None, :] == 1, 0, -mx.inf)
                mask = mx.expand_dims(mask, 1)
            else:
                mask = mask[None, None]
        else:
            past_k, past_v, past_p, past_m = cache
            position_ids = past_p[:, -1:] + 1
            mask = mx.pad(past_m[:, :, -1:, :], ((0, 0), (0, 0), (0, 0), (0, 1)))
            q, k = self.rope(q, k, position_ids)
            k = mx.concatenate([past_k, k], axis=2)
            v = mx.concatenate([past_v, v], axis=2)
        cache = (k, v, position_ids, mask)
        w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
        w += mask
        w = mx.softmax(w, axis=-1)
        o = w @ v
        o = o.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(o).astype(x.dtype), cache

class Phi3Retention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.head_dim = head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5
        chop_1 = self.n_heads * self.head_dim
        chop_2 = chop_1 + self.n_kv_heads * self.head_dim
        self.chop = [chop_1, chop_2]
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = SuRoPE(config)
        # Multi-scale decay: one gamma per head, log-spaced between 1 - 1/32 and 1 - 1/512
        xmin, xmax = math.log(1 / 32), math.log(1 / 512)
        x = mx.linspace(xmin, xmax, num=n_heads)
        self._gamma = 1 - x.exp()
        self.gn = nn.GroupNorm(num_groups=head_dim, dims=-1, affine=False)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        if use_recurrent_mode:
            return self.recurrent_mode(x, cache)
        B, L, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
        q, k = self.rope(q, k, position_ids)
        cache = None
        w = (q * self.scale) @ k.transpose(0, 1, 3, 2)
        w = w * self._decay(L)
        o = w @ v
        o = o.transpose(0, 2, 1, 3).reshape(B * L, -1)
        o = self.gn(o).reshape(B, L, -1)
        return self.o_proj(o).astype(x.dtype), cache

    def recurrent_mode(self, x, cache):
        if cache is None:
            s = mx.zeros((1, 32, 96, 96))  # (batch, n_heads, head_dim, head_dim) state, hardcoded for Phi-3.5-mini
            n = 0
        else:
            s, n = cache
        qkv = self.qkv_proj(x)
        q, k, v = mx.split(qkv, self.chop, axis=-1)
        q = q.reshape(1, 1, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(1, 1, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        position_ids = mx.array([[n]])
        q, k = self.rope(q, k, position_ids)
        k = k * self.scale
        s = self._gamma[None, :, None, None] * s + (k.transpose(0, 1, 3, 2) @ v)
        o = q @ s
        o = o.transpose(0, 2, 1, 3).reshape(1, -1)
        o = self.gn(o).reshape(1, 1, -1)
        o = self.o_proj(o).astype(x.dtype)
        return o, (s, n + 1)

    def _decay(self, sequence_length):
        n = mx.arange(sequence_length)[:, None]
        m = mx.arange(sequence_length)[None]
        D = (self._gamma[:, None, None] ** (n - m)) * (n >= m)
        return D

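The parallel path (`__call__`) and the recurrent path (`recurrent_mode`) compute the same quantity: multiplying the score matrix by the decay mask `D[n, m] = gamma**(n - m)` is equivalent to accumulating a running state `S` that decays by `gamma` each step. A single-head NumPy sketch of that equivalence (the sizes and `gamma` here are illustrative, not the model's):

```python
import numpy as np

rng = np.random.default_rng(0)
L, d = 5, 4
gamma = 0.9
Q, K, V = rng.normal(size=(3, L, d))

# Parallel mode: decay-masked scores, D[n, m] = gamma**(n - m) for n >= m, else 0
n, m = np.arange(L)[:, None], np.arange(L)[None, :]
D = np.where(n >= m, gamma ** (n - m), 0.0)
out_parallel = ((Q @ K.T) * D) @ V

# Recurrent mode: a (d, d) state updated once per token, O(1) memory in L
S = np.zeros((d, d))
out_recurrent = np.zeros((L, d))
for t in range(L):
    S = gamma * S + np.outer(K[t], V[t])
    out_recurrent[t] = Q[t] @ S

assert np.allclose(out_parallel, out_recurrent)
```

This equivalence is what lets the model train with full-sequence matrix products but decode token by token with a fixed-size state instead of a growing KV cache.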
class Phi3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def __call__(self, x):
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)

class Phi3DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.use_retention:
            self.self_attn = Phi3Retention(config)
        else:
            self.self_attn = Phi3Attention(config)
        self.mlp = Phi3MLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(self, x, position_ids, attention_mask, cache, use_recurrent_mode):
        r, cache = self.self_attn(self.input_layernorm(x), position_ids, attention_mask, cache, use_recurrent_mode)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        return h + r, cache

class Phi3Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_new = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [Phi3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(self, input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode):
        x = self.embed_new(input_ids)
        cache = [None] * len(self.layers) if cache is None else cache
        for i, l in enumerate(self.layers):
            x, cache[i] = l(x, position_ids, attention_mask, cache[i], use_recurrent_mode)
        return self.norm(x), cache

class Phi3ForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = Phi3Model(config)
        if config.untie_embedding:
            self.lm_new = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.untie = True
        else:
            self.untie = False

    def __call__(self, input_ids, pixel_values=None, image_sizes=None, position_ids=None, attention_mask=None, cache=None, use_recurrent_mode=False):
        x, cache = self.model(input_ids, pixel_values, image_sizes, position_ids, attention_mask, cache, use_recurrent_mode)
        if self.untie:
            return self.lm_new(x), cache
        return self.model.embed_new.as_linear(x), cache

    @property
    def layers(self):
        return self.model.layers

class DoRALinear(nn.Module):
    @staticmethod
    def from_linear(linear, r, alpha, scale, dropout):
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            input_dims *= 32 // linear.bits
        lora_lin = DoRALinear(input_dims=input_dims, output_dims=output_dims, r=r, alpha=alpha, scale=scale, dropout=dropout)
        lora_lin.linear = linear
        return lora_lin

    def __init__(self, input_dims, output_dims, r, alpha, scale, dropout, bias=False):
        super().__init__()
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
        self.dropout = nn.Dropout(p=dropout)
        self.scale = scale * (alpha / r)
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(low=-scale, high=scale, shape=(input_dims, r))
        self.lora_b = mx.zeros(shape=(r, output_dims))
        self.m = mx.linalg.norm(self._dequantized_weight(), axis=1).astype(mx.float32)

    def _dequantized_weight(self):
        weight = self.linear.weight
        if isinstance(self.linear, nn.QuantizedLinear):
            weight = mx.dequantize(weight, self.linear.scales, self.linear.biases, self.linear.group_size, self.linear.bits)
        return weight

    def __call__(self, x):
        y = self.linear(x)
        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
        z = y + (self.scale * z)
        adapted = self._dequantized_weight() + (self.scale * self.lora_b.T) @ self.lora_a.T
        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
        z = (self.m / denom) * z
        return z.astype(x.dtype)

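`DoRALinear` above decomposes the adapted weight into a learned per-row magnitude `m` and a direction given by the frozen weight plus the low-rank update. A NumPy sketch of the same arithmetic (dimensions are illustrative; at zero init the adapter reduces exactly to the frozen layer):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r = 8, 6, 2
W = rng.normal(size=(d_out, d_in))                 # frozen base weight
lora_a = rng.normal(size=(d_in, r)) / np.sqrt(d_in)
lora_b = np.zeros((r, d_out))                      # zero init, as in DoRALinear
scale = 1.0
m = np.linalg.norm(W, axis=1)                      # learned magnitude, init to row norms

x = rng.normal(size=(d_in,))
y = W @ x                                          # frozen path
z = y + scale * (x @ lora_a @ lora_b)              # plus low-rank update path
adapted = W + scale * (lora_b.T @ lora_a.T)        # effective adapted weight
z = (m / np.linalg.norm(adapted, axis=1)) * z      # renormalize direction, apply magnitude

# With lora_b == 0, the adapter is exactly the frozen layer.
assert np.allclose(z, W @ x)
```

Only `lora_a`, `lora_b`, and `m` are trained; the norm in the denominator carries no gradient (hence the `mx.stop_gradient` in the class above).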
def linear_to_lora_layers(model, lora_layers, lora_targets, lora_rank, lora_scale, lora_dropout):
    if lora_layers == 'all':
        lora_layers = model.layers
    elif isinstance(lora_layers, int):
        lora_layers = model.layers[-lora_layers:]
    elif isinstance(lora_layers, list):
        lora_layers = [model.layers[i] for i in lora_layers]
    else:
        raise ValueError("Invalid type for lora_layers. Expected 'all', int (number of layers), or list (layer indices).")
    def to_lora(layer):
        return DoRALinear.from_linear(layer, r=lora_rank, alpha=lora_rank, scale=lora_scale, dropout=lora_dropout)
    for l in lora_layers:
        replacements = [(k, to_lora(m)) for k, m in l.named_modules() if k in lora_targets]
        l.update_modules(tree_unflatten(replacements))

def load_base_model(model_cfg, init=False):
    model_id = 'microsoft/Phi-3.5-mini-instruct'
    model_path = snapshot_download(model_id, allow_patterns=["*.safetensors", "config.json"])
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)
    config = config | model_cfg
    model_config = SimpleNamespace(**config)
    model = Phi3ForCausalLM(model_config)
    model_weight = [(k, v) for wf in glob.glob(f"{model_path}/*.safetensors") for k, v in mx.load(wf).items()]
    model.load_weights(model_weight, strict=False)
    model.set_dtype(mx.float32)
    if init:
        init_fn_embed = nn.init.normal(mean=-0.000453949, std=0.0344238)
        model.apply_to_modules(lambda k, v: v.apply(init_fn_embed) if k.endswith('embed_new') else None)
        if model_config.untie_embedding:
            init_fn_lm = nn.init.normal(mean=-0.000231743, std=0.043457)
            model.apply_to_modules(lambda k, v: v.apply(init_fn_lm) if k.endswith('lm_new') else None)
    class_predicate = lambda k, m: hasattr(m, "to_quantized") and not k.endswith('new')
    nn.quantize(model, 64, 4, class_predicate)
    mx.eval(model.parameters())
    return model

def load_model_for_training(lora_cfg, model_cfg, thaws, from_path=None):
    model = load_base_model(model_cfg, init=False)
    if from_path:
        model.load_weights(from_path, strict=False)
    model.freeze()
    if len(lora_cfg['targets']) > 1:
        linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
    model.apply_to_modules(lambda k, v: v.unfreeze() if any(k.endswith(t) for t in thaws) else None)
    mx.eval(model.parameters())
    # print("Trainable parameters:", [i[0] for i in tree_flatten(model.trainable_parameters())])
    model.train()
    return model

def load_model_for_inference(lora_cfg, model_cfg):
    model = load_base_model(model_cfg, init=False)
    if len(lora_cfg['targets']) > 1:
        linear_to_lora_layers(model, lora_layers=lora_cfg['layers'], lora_targets=lora_cfg['targets'], lora_rank=lora_cfg['rank'], lora_scale=lora_cfg['scale'], lora_dropout=lora_cfg['dropout'])
    _path = 'trained_retnphi.safetensors' if model_cfg['use_retention'] else 'trained_orgnphi.safetensors'
    model.load_weights(_path, strict=False)
    mx.eval(model.parameters())
    model.eval()
    return model

def generate(prompt, lora_cfg, model_cfg, max_tokens=50, verbose=True):
    model = load_model_for_inference(lora_cfg=lora_cfg, model_cfg=model_cfg)
    input_ids = mx.array(tokenizer.encode(prompt))
    if model_cfg['use_retention']:
        cache = None
        for i in input_ids:
            logits, cache = model(i[None, None], cache=cache, use_recurrent_mode=True)
    else:
        logits, cache = model(input_ids[None])
    token = mx.argmax(logits[:, -1, :], axis=-1)
    mx.eval(token, cache)
    list_tokens = token.tolist()
    for i in range(max_tokens):
        logits, cache = model(token[None], cache=cache, use_recurrent_mode=True)
        token = mx.argmax(logits[:, -1, :], axis=-1)
        mx.eval(token, cache)
        list_tokens += token.tolist()
        if tokenizer.decode(list_tokens[-2:]) == '\n\n':
            break
    output = tokenizer.decode(list_tokens)
    if verbose:
        print(f'{prompt=} + {output=}\n-> {prompt+output}')
    del model
    return output

def train_gsm(learning_rates, num_epochs, batch_size, seq_length, lora_cfg, model_cfg, thaws, take, from_path=None):
    def load_gsm_data(tokenizer, is_tiny=True):
        if is_tiny:
            data = load_dataset("TinyGSM/TinyGSM")["train"]
            if take:
                data = data.take(take)
            data = data.filter(lambda x: len(x['question']) < 100 and ':' not in x['question'] and '-' not in x['question'] and "'" not in x['code'] and '\n    result =' in x['code'])
            split_point = int(len(data) * 0.8)
            train_data = data.select(range(split_point))
            eval_data = data.select(range(split_point, len(data)))
            def format_example(example):
                code_raw = example['code']
                start = code_raw.rfind('\n    """')
                if start == -1:
                    print('Wrong format to start')
                    return code_raw.strip()
                start = start + 8
                end = code_raw.rfind('\n    result =')
                if end == -1:
                    print('Wrong format to end')
                    end = len(code_raw)
                code_block = code_raw[start:end]
                code_lines = code_block.split('\n    ')
                formatted_code = '\n'.join(line.rstrip() for line in code_lines if line.strip())
                formatted_code = '\n' + formatted_code.strip() + '\n\n'
                result = (example['question'].strip(), formatted_code)
                return result
        else:
            dataset = load_dataset("openai/gsm8k", "main")
            train_data = dataset["train"]
            eval_data = dataset["test"]
            def format_example(example):
                return (example['question'].strip(), '\n' + example['answer'].strip() + '\n\n')
        train_formatted = [format_example(ex) for ex in train_data]
        eval_formatted = [format_example(ex) for ex in eval_data]
        return train_formatted, eval_formatted

    def create_batches(data, tokenizer, batch_size, seq_length):
        def _encode(x):
            return [tokenizer.encode(i) for i in x]
        encoded_data = [_encode(x) for x in data]
        encoded_data = [x for x in encoded_data if len(x[0] + x[1]) <= seq_length + 1]
        if batch_size is None:
            batch_size = min(len(encoded_data), 64)
        else:
            encoded_data = encoded_data[:(len(encoded_data) // batch_size) * batch_size]
        np.random.shuffle(encoded_data)
        for i in range(0, len(encoded_data), batch_size):
            batch = encoded_data[i:i + batch_size]
            max_len = min(max(len(q + a) - 1 for q, a in batch), seq_length)
            x_batch = []
            y_batch = []
            mask_batch = []
            for q, a in batch:
                combined = (q + a)[:max_len + 1]
                x = combined[:-1]
                y = combined[1:]
                pad_length = max_len - len(x)
                x = x + [0] * pad_length
                y = y + [0] * pad_length
                mask = [False] * (len(q) - 1) + [True] * len(a) + [False] * pad_length
                x_batch.append(x)
                y_batch.append(y)
                mask_batch.append(mask)
            yield mx.array(x_batch), mx.array(y_batch), mx.array(mask_batch)

    def loss_fn(model, X, y, mask):
        logits, _ = model(X)
        logits = logits.astype(mx.float32)
        ce = nn.losses.cross_entropy(logits, y, reduction='none')
        masked_loss = ce * mask
        return masked_loss.sum(), mask.sum()

    def evaluate(model, data, tokenizer, seq_length):
        model.eval()
        total_loss = 0
        total_samples = 0
        for X, y, mask in create_batches(data, tokenizer, None, seq_length):
            loss, ntoks = loss_fn(model, X, y, mask)
            total_loss += loss.item()
            total_samples += ntoks.item()
        return total_loss / total_samples if total_samples > 0 else -1

    def get_optimizer(train_data):
        num_batches_per_epoch = len(list(create_batches(train_data, tokenizer, batch_size, seq_length)))
        print(f'{num_batches_per_epoch=}')
        num_steps = num_epochs * num_batches_per_epoch
        num_warmup = num_steps // 10
        max_lr, min_lr = learning_rates
        if num_warmup > 2:
            warmup = optim.linear_schedule(min_lr * 0.1, max_lr, steps=num_warmup)
            cosine = optim.cosine_decay(max_lr, num_steps - num_warmup, min_lr)
            lr_schedule = optim.join_schedules([warmup, cosine], [num_warmup])
        else:
            lr_schedule = optim.cosine_decay(max_lr, num_steps, min_lr)
        return optim.Lion(learning_rate=lr_schedule), num_steps

    for arg_name in sorted(locals()):
        if arg_name != 'self':
            arg_value = locals()[arg_name]
            if not callable(arg_value):
                print(f"{arg_name}: {arg_value}")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f'--- {timestamp} ---')
    train_data, eval_data = load_gsm_data(tokenizer=tokenizer)
    model = load_model_for_training(lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws, from_path=from_path)
    optimizer, num_steps = get_optimizer(train_data)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    mx.eval(model, optimizer)
    metrics = {
        'steps': [],
        'learning_rates': [],
        'all_train_losses': [],
        'avg_train_losses': [],
        'val_losses': [],
        'trained_toks': [],
    }
    step = 0
    trained_toks = 0
    losses = []
    tic = time.perf_counter()
    for epoch in range(num_epochs):
        for X, y, loss_mask in create_batches(data=train_data, tokenizer=tokenizer, batch_size=batch_size, seq_length=seq_length):
            model.train()
            (loss, ntoks), grads = loss_and_grad_fn(model, X, y, loss_mask)
            optimizer.update(model, grads)
            mx.eval(loss, ntoks, model, optimizer)
            losses.append(loss.item())
            trained_toks += ntoks.item()
            step += 1
            if step % (num_steps // 30) == 0:
                avg_train_loss = np.mean(losses)
                lr = optimizer.learning_rate.item()
                val_loss = evaluate(model=model, data=eval_data, tokenizer=tokenizer, seq_length=seq_length)
                print(f"{avg_train_loss:8.4f} ({val_loss:6.4f}) @ {step//(num_steps//30):2}/30 w/ {lr:.2e} ({time.perf_counter() - tic:.2f} sec)")
                metrics['val_losses'].append(val_loss)
                # print(f"{avg_train_loss:8.4f} @ {step//(num_steps//30):2}/30 w/ {lr:.2e} ({time.perf_counter() - tic:.2f} sec)")
                tic = time.perf_counter()
                metrics['steps'].append(step)
                metrics['learning_rates'].append(lr)
                metrics['all_train_losses'].extend(losses)
                metrics['avg_train_losses'].append(avg_train_loss)
                metrics['trained_toks'].append(trained_toks)
                losses = []
                trained_toks = 0
    _path = 'trained_retnphi.safetensors' if model_cfg['use_retention'] else 'trained_orgnphi.safetensors'
    mx.save_safetensors(_path, dict(tree_flatten(model.trainable_parameters())))
    log = {
        'args': {
            'learning_rates': learning_rates,
            'num_epochs': num_epochs,
            'batch_size': batch_size,
            'seq_length': seq_length,
            'lora_cfg': lora_cfg,
            'model_cfg': model_cfg,
            'thaws': thaws,
            'from_path': from_path
        },
        'metrics': metrics
    }
    with open(f'train_log_{timestamp}.json', 'w') as f:
        json.dump(log, f, indent=2)
    del model

tokenizer = Tokenizer()

def main(take=1000, layers='all', targets=["self_attn.o_proj"], thaws=['new', 'post_attention_layernorm'], rank=32, scale=0.1, dropout=0.0, lr_max=1e-4, lr_min=1e-5, num_epochs=90, batch_size=1, seq_length=256, vocab_size=256, use_retention=True, untie_embedding=True, prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?'):
    lora_cfg = dict(layers=layers, targets=targets, rank=rank, scale=scale, dropout=dropout)
    model_cfg = dict(vocab_size=vocab_size, use_retention=use_retention, untie_embedding=untie_embedding)
    train_gsm(learning_rates=(lr_max, lr_min), num_epochs=num_epochs, batch_size=batch_size, seq_length=seq_length, lora_cfg=lora_cfg, model_cfg=model_cfg, thaws=thaws, take=take)
    generate(prompt=prompt, lora_cfg=lora_cfg, model_cfg=model_cfg, max_tokens=(seq_length - len(prompt)))

if __name__ == "__main__":
    main(take=None, num_epochs=3)  # -> 240916
    main(take=None, num_epochs=3, untie_embedding=False)

    main(take=None, num_epochs=3, use_retention=False)
    main(take=None, num_epochs=3, untie_embedding=False, use_retention=False)
    # fire.Fire(main)

# Output:
# 388.7268 @ 1/30 w/ 3.36e-05 (64.73 sec)
# ...
#   4.3768 @ 30/30 w/ 1.00e-05 (64.36 sec)
# prompt='There are 8 candies in a carton. How many candies will be in 5 cartons?' + output='\ncandies_in_carton = 8 \nnumber_of_cartons = 5\ntotal_no_of_candies = candies_in_carton * number_of_cartons\n\n'
# -> There are 8 candies in a carton. How many candies will be in 5 cartons?
# candies_in_carton = 8
# number_of_cartons = 5
# total_no_of_candies = candies_in_carton * number_of_cartons