"""
ScaledOptum — pure-integer optimizer for BigInt-correlation ternary training.

For each ternary module:

1. Calls ``module.update_corr()`` — pure-integer correlation accumulation:

       score = Σ (grad_sign × T) per group      (int16)
       corr_accum += score                      (int64 "BigInt"; never reset or clipped)

2. The CARRY step happens via the S computation in the module's forward:

       S = 2^E × (1 + corr_accum / (step × gs))

   The term ``corr_accum / (step × gs)`` is the continuous adjustment.

3. E is never updated manually — the corr_accum BigInt provides the
   continuous, gradient-driven adjustment to S. If desired, E can be slowly
   tracked toward the corr-derived S for better initialization at inference.

In addition, ``step()`` applies a sign-SGD update (``p -= lr * sign(grad)``)
to every parameter in the optimizer's param groups that has a gradient.
"""
import torch
from torch.optim import Optimizer


class ScaledOptum(Optimizer):
    """
    Optimizer for ternary training with BigInt correlation.

    Each ``step()`` calls ``update_corr()`` exactly once on every registered
    ternary module, even if the module is listed in several param groups.
    It then applies a sign-SGD update to the regular parameters.

    Ternary modules are live runtime references, not optimizer state. They are
    left out of ``state_dict()``, and the current lists are kept across
    ``load_state_dict()``.

    Args:
        params: iterable of parameters or param-group dicts. A group dict may
            carry a ``'ternary_modules'`` list of objects exposing
            ``update_corr()``.
        lr: step size of the sign-SGD update applied to regular parameters.
        default_group_size: stored in the param-group defaults; not read by
            this class (presumably consumed by the ternary modules — confirm).
    """

    def __init__(self, params, lr=0.3, default_group_size=32):
        defaults = dict(lr=lr, default_group_size=default_group_size)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        For each param group, in order:

        1. Call ``update_corr()`` on every ternary module of the group that
           has not already been updated during this step.
        2. For every parameter with a gradient, apply
           ``p -= group['lr'] * sign(grad)``. Sparse gradients are densified
           first.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it runs with gradients enabled.

        Returns:
            The closure's return value, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # add_ternary_modules() registers each module in *every* group, so
        # without this guard update_corr() would run once per group and
        # accumulate the same correlation several times in a single step.
        updated = set()
        for group in self.param_groups:
            for mod in group.get('ternary_modules', ()):
                if id(mod) in updated:
                    continue
                updated.add(id(mod))
                mod.update_corr()

            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    grad = grad.to_dense()
                p.add_(grad.sign(), alpha=-lr)
        return loss

    def add_ternary_modules(self, modules):
        """
        Register ternary modules with every param group.

        A module that is already registered in a group is skipped. Duplicates
        are detected by object identity. If the optimizer has no param groups,
        an empty group is created from the constructor defaults.

        Args:
            modules: iterable of objects exposing ``update_corr()``; one-shot
                iterables such as generators are supported.
        """
        # Materialize first: the loop below walks `modules` once per group.
        modules = list(modules)
        if not self.param_groups:
            self.param_groups.append(
                {**self.defaults, 'params': [], 'ternary_modules': []}
            )
        for group in self.param_groups:
            registered = group.setdefault('ternary_modules', [])
            seen = {id(m) for m in registered}
            for mod in modules:
                if id(mod) not in seen:
                    seen.add(id(mod))
                    registered.append(mod)

    def state_dict(self):
        """
        Return the optimizer state without the live ternary-module references.

        ``Optimizer.state_dict()`` copies every non-``'params'`` group key, so
        without this override whole modules would be pickled into checkpoints.
        Such checkpoints cannot be read with ``torch.load(weights_only=True)``.
        """
        state = super().state_dict()
        state['param_groups'] = [
            {k: v for k, v in grp.items() if k != 'ternary_modules'}
            for grp in state['param_groups']
        ]
        return state

    def load_state_dict(self, state_dict):
        """
        Load optimizer state while keeping each group's live ternary modules.

        ``Optimizer.load_state_dict()`` replaces every param-group dict with a
        deep copy of the saved one. Any ``'ternary_modules'`` stored there
        would therefore be detached copies of the model's modules. Those keys
        are dropped before loading, and the current lists are re-attached
        afterwards.

        Args:
            state_dict: optimizer state, as returned by ``state_dict()``.
        """
        live = [grp.get('ternary_modules') for grp in self.param_groups]
        state_dict = dict(state_dict)  # shallow copy; don't mutate the caller's dict
        saved_groups = state_dict.get('param_groups')
        if saved_groups is not None:
            state_dict['param_groups'] = [
                {k: v for k, v in grp.items() if k != 'ternary_modules'}
                for grp in saved_groups
            ]
        super().load_state_dict(state_dict)
        for grp, mods in zip(self.param_groups, live):
            if mods is not None:
                grp['ternary_modules'] = mods