Replace cpu_offload constructor param with turn_on/turn_off API (#26)
Browse files* Replace cpu_offload constructor param with turn_on/turn_off API [skip-build]
Remove the cpu_offload boolean from Muon.__init__ and add explicit
turn_on_cpu_offload() / turn_off_cpu_offload() methods instead.
state_dict and load_state_dict now require offload to be disabled
first (RuntimeError). Preserves AdamW tensor cache invalidation
on load_state_dict. Adds test_toggle_correctness test.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix: skip empty FSDP shards in CPU offload and add storage validation
Empty FSDP shards (storage size=0) were being tracked by CPUOffloadPool,
causing double-free errors on resize_(0) during offload. This led to
hangs on ranks 1-7 while rank 0 succeeded.
- Skip tensors with empty storage in track() to avoid double-free
- Add storage size validation in offload() and reload() with RuntimeError
- Add logging for turn_on/turn_off_cpu_offload
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add built binary [skip-build]
---------
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py +31 -10
- build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +16 -12
- build/torch210-cxx11-cu126-x86_64-linux/muon.py +44 -48
- build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py +31 -10
- build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +16 -12
- build/torch210-cxx11-cu128-x86_64-linux/muon.py +44 -48
- build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py +31 -10
- build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +16 -12
- build/torch210-cxx11-cu130-x86_64-linux/muon.py +44 -48
- build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py +31 -10
- build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +16 -12
- build/torch210-cxx11-rocm70-x86_64-linux/muon.py +44 -48
- build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py +31 -10
- build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +16 -12
- build/torch210-cxx11-rocm71-x86_64-linux/muon.py +44 -48
- build/torch28-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu126-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py +31 -10
- build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py +16 -12
- build/torch28-cxx11-cu126-x86_64-linux/muon.py +44 -48
- build/torch28-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu128-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py +31 -10
- build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py +16 -12
- build/torch28-cxx11-cu128-x86_64-linux/muon.py +44 -48
- build/torch28-cxx11-cu129-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu129-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py +31 -10
- build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py +16 -12
- build/torch28-cxx11-cu129-x86_64-linux/muon.py +44 -48
- build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py +31 -10
- build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py +16 -12
- build/torch28-cxx11-rocm63-x86_64-linux/muon.py +44 -48
- build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py +31 -10
- build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py +16 -12
- build/torch28-cxx11-rocm64-x86_64-linux/muon.py +44 -48
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1940944
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7be82307f66be4bb841072ecdb3d105dc73bc9ee9ca21b1ce33bddc24113f4d1
|
| 3 |
size 1940944
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2004144
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7fa10afe7c505f69ccf7a98aca116b6d551b9577ecce2dab2559c6c3b433be20
|
| 3 |
size 2004144
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2004728
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f72217eb59ce93935f593b0fdfc7f3bfc4e05f18ad9d5384c2325b27ad7ff136
|
| 3 |
size 2004728
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1866400
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:df63b7dafe62fa5e910de1123729cbe3496015e2c0110785d9bd510bf65c2eaa
|
| 3 |
size 1866400
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1866112
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7215ee4575fa44f482a98e563c5af2d60089e36d32fe8a3dcffe3fb5f587300f
|
| 3 |
size 1866112
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1936664
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37bcda6973440cdeb880e411dbaf12220cef0bab18299b4922b6a504ab109b42
|
| 3 |
size 1936664
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1999872
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9cd35014d41034ed35fbad31c19c80e7b3977cea889a865eb12db705678bb29
|
| 3 |
size 1999872
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1999872
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64072037c62afbfefe92a35b026794f2bd406fddef38bf58d318d3bae7652a29
|
| 3 |
size 1999872
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1865080
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:65135fe756ed97f2bd21fefda883b6a7b90179ebd7c0a882673239daf9d9aa6a
|
| 3 |
size 1865080
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
|
| 1022 |
+
"""Enable CPU offload for optimizer states."""
|
| 1023 |
+
if self.cpu_offload:
|
| 1024 |
+
return
|
| 1025 |
+
logger.info("[Muon] turn_on_cpu_offload")
|
| 1026 |
+
self.cpu_offload = True
|
| 1027 |
+
if not self.state:
|
| 1028 |
+
return
|
| 1029 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1030 |
+
self._offload_initialized = False
|
| 1031 |
+
self._register_states_for_offload()
|
| 1032 |
+
self._offload_initialized = True
|
| 1033 |
+
self._cpu_offload_pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
|
| 1036 |
+
"""Disable CPU offload and keep optimizer states resident on GPU."""
|
| 1037 |
+
if not self.cpu_offload:
|
| 1038 |
+
return
|
| 1039 |
+
logger.info("[Muon] turn_off_cpu_offload")
|
| 1040 |
+
if self._offload_initialized:
|
| 1041 |
self._cpu_offload_pool.reload()
|
| 1042 |
torch.cuda.current_stream().synchronize()
|
| 1043 |
+
self._cpu_offload_pool = None
|
| 1044 |
+
self._offload_initialized = False
|
| 1045 |
+
self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
|
| 1052 |
+
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
|
| 1054 |
+
return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
+
if self.cpu_offload:
|
| 1058 |
+
raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
+
# Invalidate adamw.py's module-level tensor caches so that
|
| 1062 |
+
# the next step rebuilds them with the newly loaded state tensors.
|
| 1063 |
+
_placement_cache.clear()
|
| 1064 |
+
_tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_b68ea5b_dirty
|
| 3 |
+
ops = torch.ops._optimizer_b68ea5b_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_b68ea5b_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1865168
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c562a3b7e6f3032ff56473531e9e08fceb2c86f8804080330c896ab8f0dd32af
|
| 3 |
size 1865168
|
|
@@ -68,7 +68,11 @@ class CPUOffloadPool:
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self._managed.append(tensor)
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
|
@@ -89,7 +93,10 @@ class CPUOffloadPool:
|
|
| 89 |
indices.append(idx)
|
| 90 |
offsets.append((off, n))
|
| 91 |
off += n
|
| 92 |
-
cpu_flat = torch.empty(off,
|
|
|
|
|
|
|
|
|
|
| 93 |
self._groups[dtype] = {
|
| 94 |
"indices": indices,
|
| 95 |
"offsets": offsets,
|
|
@@ -117,8 +124,7 @@ class CPUOffloadPool:
|
|
| 117 |
self._ensure_stream()
|
| 118 |
|
| 119 |
# Offload stream waits for compute to finish.
|
| 120 |
-
compute_event = torch.cuda.current_stream(
|
| 121 |
-
self._device).record_event()
|
| 122 |
self._offload_stream.wait_event(compute_event)
|
| 123 |
|
| 124 |
offloaded_bytes = 0
|
|
@@ -134,15 +140,23 @@ class CPUOffloadPool:
|
|
| 134 |
for i, mgd_idx in enumerate(indices):
|
| 135 |
local = self._local(self._managed[mgd_idx])
|
| 136 |
off, n = offsets[i]
|
| 137 |
-
cpu_flat[off:off + n].copy_(
|
| 138 |
-
|
| 139 |
|
| 140 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 141 |
|
| 142 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 143 |
self._offload_stream.synchronize()
|
| 144 |
for t in self._managed:
|
| 145 |
-
self._local(t).untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if not self._logged:
|
| 148 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
@@ -165,7 +179,14 @@ class CPUOffloadPool:
|
|
| 165 |
# Re-allocate all GPU storages first.
|
| 166 |
for t in self._managed:
|
| 167 |
local = self._local(t)
|
| 168 |
-
local.untyped_storage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 171 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
@@ -177,8 +198,8 @@ class CPUOffloadPool:
|
|
| 177 |
for i, mgd_idx in enumerate(indices):
|
| 178 |
local = self._local(self._managed[mgd_idx])
|
| 179 |
off, n = offsets[i]
|
| 180 |
-
local.reshape(-1).copy_(
|
| 181 |
-
|
| 182 |
|
| 183 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 184 |
|
|
|
|
| 68 |
local = self._local(tensor)
|
| 69 |
if self._device is None:
|
| 70 |
self._device = local.device
|
| 71 |
+
storage = local.untyped_storage()
|
| 72 |
+
# Skip tensors with empty storage (e.g. empty FSDP shards)
|
| 73 |
+
if storage.size() == 0:
|
| 74 |
+
return
|
| 75 |
+
self._storage_nbytes[tid] = storage.size()
|
| 76 |
self._managed.append(tensor)
|
| 77 |
|
| 78 |
# ------------------------------------------------------------------
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off,
|
| 97 |
+
dtype=dtype,
|
| 98 |
+
device="cpu",
|
| 99 |
+
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
|
|
| 124 |
self._ensure_stream()
|
| 125 |
|
| 126 |
# Offload stream waits for compute to finish.
|
| 127 |
+
compute_event = torch.cuda.current_stream(self._device).record_event()
|
|
|
|
| 128 |
self._offload_stream.wait_event(compute_event)
|
| 129 |
|
| 130 |
offloaded_bytes = 0
|
|
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
+
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
+
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
| 148 |
# Wait for all D2H copies to land, then free GPU storage.
|
| 149 |
self._offload_stream.synchronize()
|
| 150 |
for t in self._managed:
|
| 151 |
+
storage = self._local(t).untyped_storage()
|
| 152 |
+
if storage.size() != 0:
|
| 153 |
+
storage.resize_(0)
|
| 154 |
+
else:
|
| 155 |
+
raise RuntimeError(
|
| 156 |
+
f"Tensor storage is already freed (size=0) before offload. "
|
| 157 |
+
f"This indicates a double-free or external interference. "
|
| 158 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 159 |
+
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
|
|
|
| 179 |
# Re-allocate all GPU storages first.
|
| 180 |
for t in self._managed:
|
| 181 |
local = self._local(t)
|
| 182 |
+
storage = local.untyped_storage()
|
| 183 |
+
if storage.size() != 0:
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Storage should have been freed (size=0) before reload, "
|
| 186 |
+
f"but got size={storage.size()}. "
|
| 187 |
+
f"Tensor shape: {t.shape}, dtype: {t.dtype}"
|
| 188 |
+
)
|
| 189 |
+
storage.resize_(self._storage_nbytes[id(t)])
|
| 190 |
|
| 191 |
# Per-tensor H2D copies from CPU flat buffer slices.
|
| 192 |
# non_blocking=True with pinned source allows DMA overlap.
|
|
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
+
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
+
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
|
@@ -163,9 +163,10 @@ def construct_shard_mesh(
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
-
#
|
| 167 |
-
#
|
| 168 |
-
#
|
|
|
|
| 169 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 170 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 171 |
if key not in _ranks_to_dist_cache:
|
|
@@ -207,22 +208,25 @@ def construct_shard_mesh(
|
|
| 207 |
assert len(shard_placements) == len(set(shard_placements))
|
| 208 |
|
| 209 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 210 |
-
#
|
| 211 |
-
#
|
|
|
|
|
|
|
| 212 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 213 |
return (*t.shape, *t.flatten().tolist())
|
| 214 |
|
| 215 |
my_key = None
|
| 216 |
for sm in shard_meshes:
|
| 217 |
-
key = _cache_key(sm)
|
| 218 |
if (my_rank == sm).any().item():
|
|
|
|
| 219 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 220 |
my_key = key
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
| 228 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
|
|
| 163 |
assert mesh.mesh.device.type == 'cpu'
|
| 164 |
|
| 165 |
# -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
|
| 166 |
+
# Reuses the mesh's existing ProcessGroup directly, avoiding the
|
| 167 |
+
# overhead of dist.new_group(). The standard path below also handles
|
| 168 |
+
# subset calls safely via use_local_synchronization=True, but this
|
| 169 |
+
# fast path is still beneficial for the common 1D shard case.
|
| 170 |
if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
|
| 171 |
key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
|
| 172 |
if key not in _ranks_to_dist_cache:
|
|
|
|
| 208 |
assert len(shard_placements) == len(set(shard_placements))
|
| 209 |
|
| 210 |
# -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
|
| 211 |
+
# Each rank only creates the group it belongs to, using
|
| 212 |
+
# use_local_synchronization=True so that only group members need to
|
| 213 |
+
# coordinate. This avoids deadlocks when different PP stages call
|
| 214 |
+
# construct_shard_mesh for different parameters.
|
| 215 |
def _cache_key(t: torch.Tensor) -> tuple:
|
| 216 |
return (*t.shape, *t.flatten().tolist())
|
| 217 |
|
| 218 |
my_key = None
|
| 219 |
for sm in shard_meshes:
|
|
|
|
| 220 |
if (my_rank == sm).any().item():
|
| 221 |
+
key = _cache_key(sm)
|
| 222 |
assert my_key is None, "Rank appears in multiple shard groups"
|
| 223 |
my_key = key
|
| 224 |
+
if key not in _ranks_to_dist_cache:
|
| 225 |
+
pg = dist.new_group(sm.flatten().tolist(),
|
| 226 |
+
use_local_synchronization=True)
|
| 227 |
+
_ranks_to_dist_cache[key] = (
|
| 228 |
+
DeviceMesh(device_type="cuda", mesh=sm),
|
| 229 |
+
pg,
|
| 230 |
+
)
|
| 231 |
|
| 232 |
return (*_ranks_to_dist_cache[my_key], shard_placements)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
-
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
-
expert_keys=None
|
| 211 |
-
cpu_offload=False):
|
| 212 |
defaults = dict(
|
| 213 |
lr=lr,
|
| 214 |
weight_decay=weight_decay,
|
|
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 229 |
if param_group.get("use_muon", None) is None:
|
| 230 |
raise ValueError(
|
| 231 |
error_message.format(idx=_idx) + instruction_code)
|
| 232 |
-
|
| 233 |
super().__init__(params, defaults)
|
| 234 |
|
| 235 |
self.debug = debug
|
|
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 243 |
self.chunk_size = chunk_size
|
| 244 |
self.use_distributed_muon = use_distributed_muon
|
| 245 |
self.expert_keys = expert_keys
|
| 246 |
-
self.cpu_offload =
|
| 247 |
-
self._cpu_offload_pool
|
| 248 |
self._offload_initialized = False
|
| 249 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 250 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 1008 |
# D2H: offload optimizer states to CPU after computation.
|
| 1009 |
if self.cpu_offload:
|
| 1010 |
if not self._offload_initialized:
|
|
|
|
|
|
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
-
#
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
-
def
|
| 1022 |
-
"""
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
self._cpu_offload_pool.reload()
|
| 1033 |
torch.cuda.current_stream().synchronize()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
# (which frees GPU storage on the originals via resize_(0)).
|
| 1038 |
-
for k in sd["state"]:
|
| 1039 |
-
sd["state"][k] = {
|
| 1040 |
-
sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
|
| 1041 |
-
for sk, sv in sd["state"][k].items()
|
| 1042 |
-
}
|
| 1043 |
-
self._cpu_offload_pool.offload()
|
| 1044 |
-
return sd
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
|
|
|
| 1048 |
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# If states were offloaded, reload first so storage sizes are
|
| 1054 |
-
# correct for super().load_state_dict() to overwrite.
|
| 1055 |
-
if self.cpu_offload and self._offload_initialized:
|
| 1056 |
-
self._cpu_offload_pool.reload()
|
| 1057 |
-
torch.cuda.current_stream().synchronize()
|
| 1058 |
|
|
|
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
self._offload_initialized = False
|
| 1066 |
-
self._register_states_for_offload()
|
| 1067 |
-
self._offload_initialized = True
|
| 1068 |
-
self._cpu_offload_pool.offload()
|
|
|
|
| 8 |
from torch.distributed.tensor import DTensor, Replicate, Shard
|
| 9 |
from torch.profiler import record_function
|
| 10 |
|
| 11 |
+
from .adamw import _placement_cache, _tensor_cache, step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
|
| 14 |
get_default_muon_param_groups, is_expert_param, update_p)
|
|
|
|
| 207 |
warmup_step=5,
|
| 208 |
chunk_size=-1,
|
| 209 |
use_distributed_muon=False,
|
| 210 |
+
expert_keys=None):
|
|
|
|
| 211 |
defaults = dict(
|
| 212 |
lr=lr,
|
| 213 |
weight_decay=weight_decay,
|
|
|
|
| 228 |
if param_group.get("use_muon", None) is None:
|
| 229 |
raise ValueError(
|
| 230 |
error_message.format(idx=_idx) + instruction_code)
|
|
|
|
| 231 |
super().__init__(params, defaults)
|
| 232 |
|
| 233 |
self.debug = debug
|
|
|
|
| 241 |
self.chunk_size = chunk_size
|
| 242 |
self.use_distributed_muon = use_distributed_muon
|
| 243 |
self.expert_keys = expert_keys
|
| 244 |
+
self.cpu_offload = False
|
| 245 |
+
self._cpu_offload_pool: CPUOffloadPool | None = None
|
| 246 |
self._offload_initialized = False
|
| 247 |
self._parallel_cache: dict[tuple[str, ...], dict] = {}
|
| 248 |
self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
|
|
|
|
| 1006 |
# D2H: offload optimizer states to CPU after computation.
|
| 1007 |
if self.cpu_offload:
|
| 1008 |
if not self._offload_initialized:
|
| 1009 |
+
if self._cpu_offload_pool is None:
|
| 1010 |
+
self._cpu_offload_pool = CPUOffloadPool()
|
| 1011 |
self._register_states_for_offload()
|
| 1012 |
self._offload_initialized = True
|
| 1013 |
self._cpu_offload_pool.offload()
|
|
|
|
| 1015 |
return loss
|
| 1016 |
|
| 1017 |
# ------------------------------------------------------------------
|
| 1018 |
+
# CPU offload public helpers
|
| 1019 |
# ------------------------------------------------------------------
|
| 1020 |
|
| 1021 |
+
def turn_on_cpu_offload(self):
    """Enable CPU offload of optimizer states.

    No-op when offload is already active.  If optimizer state already
    exists, a fresh offload pool is built and states are moved to CPU
    right away; with no state yet, the pool is created lazily by the
    next step() call.
    """
    if not self.cpu_offload:
        logger.info("[Muon] turn_on_cpu_offload")
        self.cpu_offload = True
        if self.state:
            pool = CPUOffloadPool()
            self._cpu_offload_pool = pool
            self._offload_initialized = False
            self._register_states_for_offload()
            self._offload_initialized = True
            pool.offload()
|
| 1034 |
+
|
| 1035 |
+
def turn_off_cpu_offload(self):
    """Disable CPU offload; optimizer states stay resident on GPU.

    No-op when offload is not active.  If states were already offloaded,
    they are reloaded to GPU (and the copy is synchronized) before the
    pool is dropped.
    """
    if self.cpu_offload:
        logger.info("[Muon] turn_off_cpu_offload")
        if self._offload_initialized:
            # Bring states back to GPU before discarding the pool.
            self._cpu_offload_pool.reload()
            torch.cuda.current_stream().synchronize()
        self._cpu_offload_pool = None
        self._offload_initialized = False
        self.cpu_offload = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
+
# ------------------------------------------------------------------
|
| 1048 |
+
# Checkpoint support for cpu_offload
|
| 1049 |
+
# ------------------------------------------------------------------
|
| 1050 |
|
| 1051 |
+
def state_dict(self) -> dict:
    """Return the optimizer state dict for checkpointing.

    Raises:
        RuntimeError: when CPU offload is still enabled; callers must
            invoke turn_off_cpu_offload() first so state tensors are
            resident on GPU with valid storage.
    """
    if self.cpu_offload:
        raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
    return super().state_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
+
def load_state_dict(self, state_dict: dict) -> None:
    """Load optimizer state from a checkpoint.

    Raises:
        RuntimeError: when CPU offload is still enabled; callers must
            invoke turn_off_cpu_offload() before loading.
    """
    if self.cpu_offload:
        raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
    super().load_state_dict(state_dict)

    # adamw.py keeps module-level caches referencing the old state
    # tensors; drop them so the next step rebuilds against the tensors
    # that were just loaded.
    _placement_cache.clear()
    _tensor_cache.clear()
|
|
|
|
|
|
|
|
|
|
|
|