File size: 5,060 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
SignGSD — Sign Gradient-Sign Descent optimizer (implemented by the ``ScaledOptum`` class below).

A minimal optimizer for low-precision (ternary/binary) training.
Key property: discards all magnitude information. Only signs matter.
This aligns with ternary weight domains where weights are {-1, 0, +1}
and updates are discrete flips rather than continuous steps.

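Memory: zero persistent optimizer state (no momentum buffers); only
short-lived per-step temporaries are allocated. Persistent storage is just
what torch already tracks (params + grads): 0 bytes/param of overhead vs
AdamW's 8 bytes/param (two float32 moment buffers).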
"""
import torch
from torch.optim import Optimizer


class ScaledOptum(Optimizer):
    """
    Sign Gradient-Sign Descent (SignGSD).

    Update rule, applied element-wise to every parameter that has a gradient:
        p <- p - lr * (sign(grad) + wd * sign(p))

    Compared to AdamW:
      - No first/second moment estimates (no exp_avg, exp_avg_sq), so no
        per-parameter optimizer state is stored.
      - No adaptive per-parameter learning rate: within a param group every
        element moves by lr * |sign(grad) + wd * sign(p)|.
      - Weight decay acts on sign(p), not p itself, so every nonzero weight
        receives the same decay magnitude regardless of |p|.

    Intended use for ternary training:
      Ternary weights live in {-1, 0, +1}. Continuous updates like
      p -= lr * grad immediately leave the ternary domain. SignGSD only
      votes on direction; the discrete flip decision is expected to be made
      elsewhere (e.g. a counter that accumulates sign votes and flips at a
      threshold). That mechanism is not part of this class.
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        """
        Args:
            params: iterable of parameters or param groups.
            lr: step size shared by every element of a param group (no
                adaptive scaling). Must be >= 0.
            weight_decay: coefficient applied to sign(p), not to p. Must be >= 0.
                - Where sign(grad) == sign(p): the gradient step already moves
                  p toward zero and decay adds to it, so |update| = 1 + wd.
                - Where the signs disagree: the gradient step moves p away
                  from zero and decay opposes it, so |update| = |1 - wd|
                  (exactly 0 when wd == 1).
                - Where an element's gradient is exactly 0: decay alone
                  moves p toward zero by lr * wd.
                - Zero weights (sign(p) == 0) receive no decay.

        Raises:
            ValueError: if lr or weight_decay is negative (or NaN).
        """
        # `not 0.0 <= x` also rejects NaN, matching torch.optim's built-ins.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Flow, per parameter that has a gradient (params with grad None are
        skipped):
          1. update = sign(grad): -1, 0 or +1 per element. All gradient
             magnitude is discarded. Sparse gradients are densified first.
          2. If weight_decay > 0: update += wd * sign(p).
          3. p -= lr * update, in place.

        Memory: no persistent optimizer state is created (self.state is never
        written). Each step allocates only per-parameter temporaries:
          - the update tensor;
          - sign(p), when decay is on;
          - a dense copy of a sparse gradient.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The loss returned by closure, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd
            # enabled to recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    grad = grad.to_dense()

                # sign() returns a fresh tensor, so the in-place ops below
                # never touch p.grad. Zero gradient -> zero vote.
                update = grad.sign()

                if wd > 0:
                    # Decay on sign(p): for ternary p in {-1, 0, +1},
                    # sign(p) == p, so this is wd * p there as well.
                    update.add_(p.sign(), alpha=wd)

                # NOTE(review): whether p is a latent/accumulator tensor rather
                # than the ternary weights themselves depends on the caller's
                # training loop; this optimizer writes to p unconditionally.
                update.mul_(lr)
                p.sub_(update)

        return loss

    @torch.no_grad()
    def get_memory_mb(self, params=None) -> float:
        """
        Return the total size of the given parameter tensors in MiB.

        Only the parameter tensors themselves are counted
        (numel * element_size). Their .grad tensors are not included, and
        this optimizer keeps no per-parameter state to add on top. For
        comparison, AdamW keeps two extra float32 buffers per parameter
        (8 bytes/param).

        Args:
            params: iterable of tensors to measure. Defaults to every
                parameter in this optimizer's param groups.

        Returns:
            Total size in MiB (bytes / 2**20).
        """
        if params is None:
            params = (p for group in self.param_groups for p in group["params"])
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)