"""
ScaledOptum — Pure-integer optimizer for BigInt-correlation ternary training.
For each ternary module:
1. Calls module.update_corr() — pure-integer correlation accumulation
score = Σ (grad_sign × T) per group (int16)
corr_accum += score (int64, BigInt, never resets or clips)
2. The CARRY step happens via the S computation in forward:
S = 2^E × (1 + corr_accum / (step × gs))
The corr_accum / (step × gs) is the continuous adjustment.
3. E is never manually updated — the corr_accum BigInt provides
the continuous gradient-driven adjustment to S.
If desired, E can be slowly tracked toward the corr-derived S
for better initialization at inference.
"""
import torch
from torch.optim import Optimizer
class ScaledOptum(Optimizer):
    """
    Optimizer for ternary training with BigInt-correlation modules.

    Each call to :meth:`step` does two things:

    1. Calls ``update_corr()`` exactly once on every ternary module
       registered under the ``'ternary_modules'`` key of any param group.
       The optimizer keeps no per-parameter state of its own.
    2. Applies a sign-SGD update ``p -= lr * sign(grad)`` to every
       parameter in ``group['params']`` that has a gradient.

    Args:
        params: Iterable of parameters or param-group dicts. A group dict
            may carry a ``'ternary_modules'`` list.
        lr: Step size for the sign update of ordinary parameters.
        default_group_size: Stored in the group defaults; not read by this
            class (presumably consumed by callers or the ternary modules —
            TODO confirm).

    Raises:
        ValueError: If ``lr`` is negative.
    """

    def __init__(self, params, lr=0.3, default_group_size=32):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr, default_group_size=default_group_size)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Pass 1: correlation update. add_ternary_modules() registers each
        # module in *every* group, so without de-duplication a module would
        # be updated once per group and its accumulator over-counted.
        seen = set()
        for group in self.param_groups:
            for mod in group.get('ternary_modules', ()):
                if id(mod) in seen:
                    continue
                seen.add(id(mod))
                mod.update_corr()

        # Pass 2: sign-SGD on the ordinary parameters of each group.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    grad = grad.to_dense()
                p.add_(-group['lr'] * grad.sign())
        return loss

    def add_ternary_modules(self, modules):
        """
        Register ternary modules with every param group.

        Modules already present in a group are not added again. If the
        optimizer has no param groups, an empty one is created that uses
        the optimizer's own defaults.

        Args:
            modules: Iterable of objects exposing ``update_corr()``. It is
                materialized once, so generators are supported.
        """
        # Materialize first: the loop below walks the modules once per
        # group, which would silently exhaust a generator after group 0.
        modules = list(modules)
        if not self.param_groups:
            # add_param_group fills lr/default_group_size from
            # self.defaults instead of hard-coded values.
            self.add_param_group({'params': [], 'ternary_modules': []})
        # NOTE(review): the module list lives inside param_groups, so it is
        # likely carried through Optimizer.state_dict()/load_state_dict();
        # confirm checkpoint behavior.
        for group in self.param_groups:
            registered = group.setdefault('ternary_modules', [])
            for mod in modules:
                if mod not in registered:
                    registered.append(mod)