"""
ScaledOptum — pure-integer optimizer for BigInt-correlation ternary training.

For each ternary module:
  1. Calls module.update_corr() — pure-integer correlation accumulation:
       score = Σ (grad_sign × T) per group (int16)
       corr_accum += score (int64 "BigInt"; never reset or clipped)

  2. The CARRY step happens via the S computation in forward:
       S = 2^E × (1 + corr_accum / (step × gs))
     The term corr_accum / (step × gs) is the continuous adjustment.

  3. E is never updated manually — the corr_accum BigInt provides the
     continuous, gradient-driven adjustment to S. If desired, E can be
     slowly tracked toward the corr-derived S for better initialization
     at inference.

Ordinary (non-ternary) parameters in each param group receive a sign-SGD
update: p ← p − lr × sign(grad).
"""
| import torch |
| from torch.optim import Optimizer |
|
|
|
|
class ScaledOptum(Optimizer):
    """
    Pure-integer optimizer for ternary training with BigInt correlation.

    On every ``step()``:

    * each registered ternary module has ``update_corr()`` called exactly
      once, even when the same module is registered in several param
      groups (``add_ternary_modules`` registers it in every group);
    * every ordinary parameter that has a gradient receives a sign-SGD
      update, ``p <- p - lr * sign(grad)``, using its group's ``lr``.

    The optimizer writes no per-parameter state (``self.state`` is never
    touched). Ternary modules are kept per param group under the
    ``'ternary_modules'`` key.

    NOTE(review): because the modules live inside ``param_groups``,
    ``Optimizer.state_dict()`` will include the module objects themselves.
    Confirm this is acceptable before checkpointing with ``torch.save``.

    Args:
        params: iterable of parameters or param-group dicts, as accepted by
            ``torch.optim.Optimizer``. A group dict may carry its own
            ``'ternary_modules'`` list.
        lr: step size of the sign-SGD update for ordinary parameters.
            Must be non-negative.
        default_group_size: stored as a per-group default. It is not read
            by this class; presumably it is consumed elsewhere (TODO confirm).

    Raises:
        ValueError: if ``lr`` is negative.
    """

    def __init__(self, params, lr=0.3, default_group_size=32):
        # Mirror torch.optim.SGD: a negative step size would silently turn
        # the sign-SGD update into gradient ascent.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr, default_group_size=default_group_size)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # add_ternary_modules() puts each module into *every* param group.
        # Track identities so update_corr() runs once per module per step.
        # Per the module docstring, each call adds a fresh score to
        # corr_accum, so a repeated call would double-count.
        updated = set()
        for group in self.param_groups:
            for mod in group.get('ternary_modules', ()):
                if id(mod) in updated:
                    continue
                updated.add(id(mod))
                mod.update_corr()

            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    # Densify before taking the sign so the update is a
                    # regular tensor of p's shape.
                    grad = grad.to_dense()
                # Sign-SGD: each element moves by exactly lr (or 0 where
                # the gradient is 0).
                p.add_(-lr * grad.sign())

        return loss

    def add_ternary_modules(self, modules):
        """
        Register ternary modules with every param group.

        A module already present in a group (compared by identity) is not
        added again. If the optimizer has no param groups, an empty group
        is created using the optimizer's defaults.

        Args:
            modules: iterable of objects exposing ``update_corr()``. It may
                be a one-shot iterator such as a generator; it is consumed
                exactly once.
        """
        # Materialize first: the loop below walks the modules once per
        # group, and a generator would be exhausted after the first group.
        modules = list(modules)

        if not self.param_groups:
            # add_param_group fills in lr / default_group_size from
            # self.defaults instead of hard-coding them.
            self.add_param_group({'params': [], 'ternary_modules': []})

        for group in self.param_groups:
            registered = group.setdefault('ternary_modules', [])
            seen = {id(m) for m in registered}
            for mod in modules:
                if id(mod) not in seen:
                    seen.add(id(mod))
                    registered.append(mod)
|
|