File size: 2,280 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
ScaledOptum — Pure-integer optimizer for BigInt-correlation ternary training.

For each ternary module:
  1. Calls module.update_corr() — pure-integer correlation accumulation
     score = Σ (grad_sign × T) per group  (int16)
     corr_accum += score  (int64, BigInt, never resets or clips)
  
  2. The CARRY step happens via the S computation in forward:
     S = 2^E × (1 + corr_accum / (step × gs))
     The corr_accum / (step × gs) is the continuous adjustment.
     
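  3. E is never manually updated — the corr_accum BigInt provides
     the continuous gradient-driven adjustment to S.
     If desired, E can be slowly tracked toward the corr-derived S
     for better initialization at inference.

  4. Any ordinary parameters handed to the optimizer receive a
     sign-gradient update in step():  p -= lr × sign(grad)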
"""
import torch
from torch.optim import Optimizer


class ScaledOptum(Optimizer):
    """
    Optimizer for ternary training with BigInt correlation.

    Each call to ``step()`` does two things:

    * calls ``update_corr()`` once on every ternary module registered via
      ``add_ternary_modules``, and
    * applies a sign-gradient update ``p -= lr * sign(p.grad)`` to every
      parameter in ``param_groups`` that has a gradient.

    The optimizer itself never writes per-parameter entries into
    ``self.state``.

    Args:
        params: iterable of parameters or param-group dicts, as accepted by
            ``torch.optim.Optimizer``.
        lr: step size of the sign-gradient update.
        default_group_size: stored in the group defaults; it is not read
            anywhere in this class.
    """

    def __init__(self, params, lr=0.3, default_group_size=32):
        defaults = dict(lr=lr, default_group_size=default_group_size)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # add_ternary_modules() registers each module in every param group,
        # so the same module can be reached several times in this loop.
        # Track identities to guarantee exactly one update_corr() per step.
        updated = set()
        for group in self.param_groups:
            for mod in group.get('ternary_modules', ()):
                if id(mod) in updated:
                    continue
                updated.add(id(mod))
                mod.update_corr()

            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    grad = grad.to_dense()
                # Sign-gradient step: only the direction of the gradient
                # is used, scaled by the group's learning rate.
                p.add_(-lr * grad.sign())

        return loss

    def add_ternary_modules(self, modules):
        """Register ternary modules whose ``update_corr()`` runs in ``step()``.

        Each module is appended to the ``'ternary_modules'`` list of every
        param group unless that exact object is already present.

        Args:
            modules: iterable of objects exposing ``update_corr()``.

        NOTE(review): the modules are stored inside ``param_groups``, so
        they presumably end up in ``state_dict()`` as well — confirm that
        this is intended before checkpointing the optimizer.
        """
        # Materialize once so a generator is not exhausted after the
        # first param group.
        modules = list(modules)

        if not self.param_groups:
            # Build the fallback group from the configured defaults rather
            # than hard-coded values, so a custom lr is respected.
            fallback = dict(self.defaults)
            fallback['params'] = []
            fallback['ternary_modules'] = []
            self.param_groups.append(fallback)

        for group in self.param_groups:
            registered = group.setdefault('ternary_modules', [])
            for mod in modules:
                # Identity comparison: two distinct modules must both be
                # registered even if they happen to compare equal.
                if not any(mod is known for known in registered):
                    registered.append(mod)