Kernels
wyldecat Claude Opus 4.6 (1M context) github-actions[bot] committed on
Commit
05a75f1
·
unverified ·
1 Parent(s): b61425a

Replace cpu_offload constructor param with turn_on/turn_off API (#26)

Browse files

* Replace cpu_offload constructor param with turn_on/turn_off API [skip-build]

Remove the cpu_offload boolean from Muon.__init__ and add explicit
turn_on_cpu_offload() / turn_off_cpu_offload() methods instead.
state_dict and load_state_dict now require offload to be disabled
first (RuntimeError). Preserves AdamW tensor cache invalidation
on load_state_dict. Adds test_toggle_correctness test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: skip empty FSDP shards in CPU offload and add storage validation

Empty FSDP shards (storage size=0) were being tracked by CPUOffloadPool,
causing double-free errors on resize_(0) during offload. This led to
hangs on ranks 1-7 while rank 0 succeeded.

- Skip tensors with empty storage in track() to avoid double-free
- Add storage size validation in offload() and reload() with RuntimeError
- Add logging for turn_on/turn_off_cpu_offload

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add built binary [skip-build]

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  2. build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  3. build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py +31 -10
  4. build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py +16 -12
  5. build/torch210-cxx11-cu126-x86_64-linux/muon.py +44 -48
  6. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  7. build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  8. build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py +31 -10
  9. build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py +16 -12
  10. build/torch210-cxx11-cu128-x86_64-linux/muon.py +44 -48
  11. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  12. build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  13. build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py +31 -10
  14. build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py +16 -12
  15. build/torch210-cxx11-cu130-x86_64-linux/muon.py +44 -48
  16. build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
  17. build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  18. build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py +31 -10
  19. build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py +16 -12
  20. build/torch210-cxx11-rocm70-x86_64-linux/muon.py +44 -48
  21. build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +3 -3
  22. build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  23. build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py +31 -10
  24. build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +16 -12
  25. build/torch210-cxx11-rocm71-x86_64-linux/muon.py +44 -48
  26. build/torch28-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  27. build/torch28-cxx11-cu126-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  28. build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py +31 -10
  29. build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py +16 -12
  30. build/torch28-cxx11-cu126-x86_64-linux/muon.py +44 -48
  31. build/torch28-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  32. build/torch28-cxx11-cu128-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  33. build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py +31 -10
  34. build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py +16 -12
  35. build/torch28-cxx11-cu128-x86_64-linux/muon.py +44 -48
  36. build/torch28-cxx11-cu129-x86_64-linux/_ops.py +3 -3
  37. build/torch28-cxx11-cu129-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  38. build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py +31 -10
  39. build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py +16 -12
  40. build/torch28-cxx11-cu129-x86_64-linux/muon.py +44 -48
  41. build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +3 -3
  42. build/torch28-cxx11-rocm63-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  43. build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py +31 -10
  44. build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py +16 -12
  45. build/torch28-cxx11-rocm63-x86_64-linux/muon.py +44 -48
  46. build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +3 -3
  47. build/torch28-cxx11-rocm64-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} +1 -1
  48. build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py +31 -10
  49. build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py +16 -12
  50. build/torch28-cxx11-rocm64-x86_64-linux/muon.py +44 -48
build/torch210-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:90ace47a61519aefe759810c803789e7f91e6949ca0b04fc177e311709976334
3
  size 1940944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7be82307f66be4bb841072ecdb3d105dc73bc9ee9ca21b1ce33bddc24113f4d1
3
  size 1940944
build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch210-cxx11-cu126-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1abfa69cd254e0000246a074c0bfa53c2e72bb53cc5fa8216275295cd021c57a
3
  size 2004144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fa10afe7c505f69ccf7a98aca116b6d551b9577ecce2dab2559c6c3b433be20
3
  size 2004144
build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch210-cxx11-cu128-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch210-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6869cfabdf45c7092d251846b3099287f8bccd5c5ebe7edf1a5fd21436324349
3
  size 2004728
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f72217eb59ce93935f593b0fdfc7f3bfc4e05f18ad9d5384c2325b27ad7ff136
3
  size 2004728
build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch210-cxx11-cu130-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch210-cxx11-rocm70-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0102e10121a43f6d5d59a23f2c0e21d88469cc4597d84f7d48b64b0fabfeacdb
3
  size 1866400
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df63b7dafe62fa5e910de1123729cbe3496015e2c0110785d9bd510bf65c2eaa
3
  size 1866400
build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch210-cxx11-rocm70-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch210-cxx11-rocm71-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f41709878a4def27b12f4f9a4f5b767027fb33141e775f64ad04d434fcbe33d9
3
  size 1866112
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7215ee4575fa44f482a98e563c5af2d60089e36d32fe8a3dcffe3fb5f587300f
3
  size 1866112
build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch210-cxx11-rocm71-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch28-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:08e3ee2f567d7a89ba34a82429c2f47cdb17d53ef66dc7d5751cabeafa01ce0f
3
  size 1936664
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37bcda6973440cdeb880e411dbaf12220cef0bab18299b4922b6a504ab109b42
3
  size 1936664
build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch28-cxx11-cu126-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch28-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e08baad750646c67f23c6e7c4d0e1b7266eeffed3bbb730729ba8f37e120a2b1
3
  size 1999872
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9cd35014d41034ed35fbad31c19c80e7b3977cea889a865eb12db705678bb29
3
  size 1999872
build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch28-cxx11-cu128-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch28-cxx11-cu129-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c485caa290f4b43e49db4ceafe25f0d0840dcdd61d02a5aecfa78d8f9acc9b60
3
  size 1999872
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64072037c62afbfefe92a35b026794f2bd406fddef38bf58d318d3bae7652a29
3
  size 1999872
build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch28-cxx11-cu129-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch28-cxx11-rocm63-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a57395ef49976af61778f127cfdeace6a4c35b491b9903e48b1cd7199ee217c
3
  size 1865080
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65135fe756ed97f2bd21fefda883b6a7b90179ebd7c0a882673239daf9d9aa6a
3
  size 1865080
build/torch28-cxx11-rocm63-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch28-cxx11-rocm63-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()
 
 
 
 
build/torch28-cxx11-rocm64-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_5b58933_dirty
3
- ops = torch.ops._optimizer_5b58933_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_5b58933_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_b68ea5b_dirty
3
+ ops = torch.ops._optimizer_b68ea5b_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_b68ea5b_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/{_optimizer_5b58933_dirty.abi3.so → _optimizer_b68ea5b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fd57f2197a2107ad920abbce3e2c986b79c76cb864f693f53bd389b26b763902
3
  size 1865168
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c562a3b7e6f3032ff56473531e9e08fceb2c86f8804080330c896ab8f0dd32af
3
  size 1865168
build/torch28-cxx11-rocm64-x86_64-linux/cpu_offload.py CHANGED
@@ -68,7 +68,11 @@ class CPUOffloadPool:
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
- self._storage_nbytes[tid] = local.untyped_storage().size()
 
 
 
 
72
  self._managed.append(tensor)
73
 
74
  # ------------------------------------------------------------------
@@ -89,7 +93,10 @@ class CPUOffloadPool:
89
  indices.append(idx)
90
  offsets.append((off, n))
91
  off += n
92
- cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
 
 
 
93
  self._groups[dtype] = {
94
  "indices": indices,
95
  "offsets": offsets,
@@ -117,8 +124,7 @@ class CPUOffloadPool:
117
  self._ensure_stream()
118
 
119
  # Offload stream waits for compute to finish.
120
- compute_event = torch.cuda.current_stream(
121
- self._device).record_event()
122
  self._offload_stream.wait_event(compute_event)
123
 
124
  offloaded_bytes = 0
@@ -134,15 +140,23 @@ class CPUOffloadPool:
134
  for i, mgd_idx in enumerate(indices):
135
  local = self._local(self._managed[mgd_idx])
136
  off, n = offsets[i]
137
- cpu_flat[off:off + n].copy_(
138
- local.reshape(-1), non_blocking=True)
139
 
140
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
141
 
142
  # Wait for all D2H copies to land, then free GPU storage.
143
  self._offload_stream.synchronize()
144
  for t in self._managed:
145
- self._local(t).untyped_storage().resize_(0)
 
 
 
 
 
 
 
 
146
 
147
  if not self._logged:
148
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
@@ -165,7 +179,14 @@ class CPUOffloadPool:
165
  # Re-allocate all GPU storages first.
166
  for t in self._managed:
167
  local = self._local(t)
168
- local.untyped_storage().resize_(self._storage_nbytes[id(t)])
 
 
 
 
 
 
 
169
 
170
  # Per-tensor H2D copies from CPU flat buffer slices.
171
  # non_blocking=True with pinned source allows DMA overlap.
@@ -177,8 +198,8 @@ class CPUOffloadPool:
177
  for i, mgd_idx in enumerate(indices):
178
  local = self._local(self._managed[mgd_idx])
179
  off, n = offsets[i]
180
- local.reshape(-1).copy_(
181
- cpu_flat[off:off + n], non_blocking=True)
182
 
183
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
184
 
 
68
  local = self._local(tensor)
69
  if self._device is None:
70
  self._device = local.device
71
+ storage = local.untyped_storage()
72
+ # Skip tensors with empty storage (e.g. empty FSDP shards)
73
+ if storage.size() == 0:
74
+ return
75
+ self._storage_nbytes[tid] = storage.size()
76
  self._managed.append(tensor)
77
 
78
  # ------------------------------------------------------------------
 
93
  indices.append(idx)
94
  offsets.append((off, n))
95
  off += n
96
+ cpu_flat = torch.empty(off,
97
+ dtype=dtype,
98
+ device="cpu",
99
+ pin_memory=True)
100
  self._groups[dtype] = {
101
  "indices": indices,
102
  "offsets": offsets,
 
124
  self._ensure_stream()
125
 
126
  # Offload stream waits for compute to finish.
127
+ compute_event = torch.cuda.current_stream(self._device).record_event()
 
128
  self._offload_stream.wait_event(compute_event)
129
 
130
  offloaded_bytes = 0
 
140
  for i, mgd_idx in enumerate(indices):
141
  local = self._local(self._managed[mgd_idx])
142
  off, n = offsets[i]
143
+ cpu_flat[off:off + n].copy_(local.reshape(-1),
144
+ non_blocking=True)
145
 
146
  offloaded_bytes += grp["total"] * cpu_flat.element_size()
147
 
148
  # Wait for all D2H copies to land, then free GPU storage.
149
  self._offload_stream.synchronize()
150
  for t in self._managed:
151
+ storage = self._local(t).untyped_storage()
152
+ if storage.size() != 0:
153
+ storage.resize_(0)
154
+ else:
155
+ raise RuntimeError(
156
+ f"Tensor storage is already freed (size=0) before offload. "
157
+ f"This indicates a double-free or external interference. "
158
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
159
+ )
160
 
161
  if not self._logged:
162
  logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
 
179
  # Re-allocate all GPU storages first.
180
  for t in self._managed:
181
  local = self._local(t)
182
+ storage = local.untyped_storage()
183
+ if storage.size() != 0:
184
+ raise RuntimeError(
185
+ f"Storage should have been freed (size=0) before reload, "
186
+ f"but got size={storage.size()}. "
187
+ f"Tensor shape: {t.shape}, dtype: {t.dtype}"
188
+ )
189
+ storage.resize_(self._storage_nbytes[id(t)])
190
 
191
  # Per-tensor H2D copies from CPU flat buffer slices.
192
  # non_blocking=True with pinned source allows DMA overlap.
 
198
  for i, mgd_idx in enumerate(indices):
199
  local = self._local(self._managed[mgd_idx])
200
  off, n = offsets[i]
201
+ local.reshape(-1).copy_(cpu_flat[off:off + n],
202
+ non_blocking=True)
203
 
204
  reloaded_bytes += grp["total"] * cpu_flat.element_size()
205
 
build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py CHANGED
@@ -163,9 +163,10 @@ def construct_shard_mesh(
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
- # This avoids a non-collective dist.new_group() call, which would
167
- # deadlock when only a subset of ranks call this function (e.g. expert
168
- # DTensors on a TP submesh where ranks 0-3 and 4-7 call separately).
 
169
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
170
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
171
  if key not in _ranks_to_dist_cache:
@@ -207,22 +208,25 @@ def construct_shard_mesh(
207
  assert len(shard_placements) == len(set(shard_placements))
208
 
209
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
210
- # All ranks must call dist.new_group in the same order, even though each
211
- # rank only joins one group.
 
 
212
  def _cache_key(t: torch.Tensor) -> tuple:
213
  return (*t.shape, *t.flatten().tolist())
214
 
215
  my_key = None
216
  for sm in shard_meshes:
217
- key = _cache_key(sm)
218
  if (my_rank == sm).any().item():
 
219
  assert my_key is None, "Rank appears in multiple shard groups"
220
  my_key = key
221
- if key not in _ranks_to_dist_cache:
222
- pg = dist.new_group(sm.flatten().tolist())
223
- _ranks_to_dist_cache[key] = (
224
- DeviceMesh(device_type="cuda", mesh=sm),
225
- pg,
226
- )
 
227
 
228
  return (*_ranks_to_dist_cache[my_key], shard_placements)
 
163
  assert mesh.mesh.device.type == 'cpu'
164
 
165
  # -- Fast path: 1D all-shard mesh → reuse existing PG. ----------------
166
+ # Reuses the mesh's existing ProcessGroup directly, avoiding the
167
+ # overhead of dist.new_group(). The standard path below also handles
168
+ # subset calls safely via use_local_synchronization=True, but this
169
+ # fast path is still beneficial for the common 1D shard case.
170
  if mesh.ndim == 1 and len(placements) == 1 and _is_shard(placements[0]):
171
  key = (*mesh.mesh.shape, *mesh.mesh.flatten().tolist())
172
  if key not in _ranks_to_dist_cache:
 
208
  assert len(shard_placements) == len(set(shard_placements))
209
 
210
  # -- Step 4: Create/retrieve ProcessGroup for current rank's sub-mesh. --
211
+ # Each rank only creates the group it belongs to, using
212
+ # use_local_synchronization=True so that only group members need to
213
+ # coordinate. This avoids deadlocks when different PP stages call
214
+ # construct_shard_mesh for different parameters.
215
  def _cache_key(t: torch.Tensor) -> tuple:
216
  return (*t.shape, *t.flatten().tolist())
217
 
218
  my_key = None
219
  for sm in shard_meshes:
 
220
  if (my_rank == sm).any().item():
221
+ key = _cache_key(sm)
222
  assert my_key is None, "Rank appears in multiple shard groups"
223
  my_key = key
224
+ if key not in _ranks_to_dist_cache:
225
+ pg = dist.new_group(sm.flatten().tolist(),
226
+ use_local_synchronization=True)
227
+ _ranks_to_dist_cache[key] = (
228
+ DeviceMesh(device_type="cuda", mesh=sm),
229
+ pg,
230
+ )
231
 
232
  return (*_ranks_to_dist_cache[my_key], shard_placements)
build/torch28-cxx11-rocm64-x86_64-linux/muon.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
- from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
@@ -207,8 +207,7 @@ class Muon(torch.optim.Optimizer):
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
- expert_keys=None,
211
- cpu_offload=False):
212
  defaults = dict(
213
  lr=lr,
214
  weight_decay=weight_decay,
@@ -229,7 +228,6 @@ class Muon(torch.optim.Optimizer):
229
  if param_group.get("use_muon", None) is None:
230
  raise ValueError(
231
  error_message.format(idx=_idx) + instruction_code)
232
-
233
  super().__init__(params, defaults)
234
 
235
  self.debug = debug
@@ -243,8 +241,8 @@ class Muon(torch.optim.Optimizer):
243
  self.chunk_size = chunk_size
244
  self.use_distributed_muon = use_distributed_muon
245
  self.expert_keys = expert_keys
246
- self.cpu_offload = cpu_offload
247
- self._cpu_offload_pool = CPUOffloadPool() if cpu_offload else None
248
  self._offload_initialized = False
249
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
250
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
@@ -1008,6 +1006,8 @@ class Muon(torch.optim.Optimizer):
1008
  # D2H: offload optimizer states to CPU after computation.
1009
  if self.cpu_offload:
1010
  if not self._offload_initialized:
 
 
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
@@ -1015,54 +1015,50 @@ class Muon(torch.optim.Optimizer):
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
- # Checkpoint support for cpu_offload
1019
  # ------------------------------------------------------------------
1020
 
1021
- def state_dict(self) -> dict:
1022
- """Return optimizer state dict, reloading offloaded states first.
1023
-
1024
- When ``cpu_offload=True``, optimizer state tensors have their GPU
1025
- storage freed (``resize_(0)``) between steps. We reload them,
1026
- snapshot the state dict, then re-offload so the optimizer stays
1027
- in the expected post-step state. The returned dict holds cloned
1028
- tensors so they remain valid after the re-offload frees the
1029
- originals' GPU storage.
1030
- """
1031
- if self.cpu_offload and self._offload_initialized:
 
 
 
 
 
 
 
 
 
1032
  self._cpu_offload_pool.reload()
1033
  torch.cuda.current_stream().synchronize()
1034
- sd = super().state_dict()
1035
- if self.cpu_offload and self._offload_initialized:
1036
- # Clone state tensors so the returned dict survives re-offload
1037
- # (which frees GPU storage on the originals via resize_(0)).
1038
- for k in sd["state"]:
1039
- sd["state"][k] = {
1040
- sk: sv.clone() if isinstance(sv, torch.Tensor) else sv
1041
- for sk, sv in sd["state"][k].items()
1042
- }
1043
- self._cpu_offload_pool.offload()
1044
- return sd
1045
 
1046
- def load_state_dict(self, state_dict: dict) -> None:
1047
- """Load optimizer state dict, then offload states if needed.
 
1048
 
1049
- After ``super().load_state_dict()`` populates GPU tensors, we
1050
- re-register them with the offload pool and offload to CPU so the
1051
- optimizer is in the same post-step state (GPU storage freed).
1052
- """
1053
- # If states were offloaded, reload first so storage sizes are
1054
- # correct for super().load_state_dict() to overwrite.
1055
- if self.cpu_offload and self._offload_initialized:
1056
- self._cpu_offload_pool.reload()
1057
- torch.cuda.current_stream().synchronize()
1058
 
 
 
 
1059
  super().load_state_dict(state_dict)
1060
 
1061
- if self.cpu_offload:
1062
- # Re-create the offload pool since state tensors may be new
1063
- # objects after load_state_dict.
1064
- self._cpu_offload_pool = CPUOffloadPool()
1065
- self._offload_initialized = False
1066
- self._register_states_for_offload()
1067
- self._offload_initialized = True
1068
- self._cpu_offload_pool.offload()
 
8
  from torch.distributed.tensor import DTensor, Replicate, Shard
9
  from torch.profiler import record_function
10
 
11
+ from .adamw import _placement_cache, _tensor_cache, step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
14
  get_default_muon_param_groups, is_expert_param, update_p)
 
207
  warmup_step=5,
208
  chunk_size=-1,
209
  use_distributed_muon=False,
210
+ expert_keys=None):
 
211
  defaults = dict(
212
  lr=lr,
213
  weight_decay=weight_decay,
 
228
  if param_group.get("use_muon", None) is None:
229
  raise ValueError(
230
  error_message.format(idx=_idx) + instruction_code)
 
231
  super().__init__(params, defaults)
232
 
233
  self.debug = debug
 
241
  self.chunk_size = chunk_size
242
  self.use_distributed_muon = use_distributed_muon
243
  self.expert_keys = expert_keys
244
+ self.cpu_offload = False
245
+ self._cpu_offload_pool: CPUOffloadPool | None = None
246
  self._offload_initialized = False
247
  self._parallel_cache: dict[tuple[str, ...], dict] = {}
248
  self._expert_expand_cache: dict[tuple[int, ...], dict] = {}
 
1006
  # D2H: offload optimizer states to CPU after computation.
1007
  if self.cpu_offload:
1008
  if not self._offload_initialized:
1009
+ if self._cpu_offload_pool is None:
1010
+ self._cpu_offload_pool = CPUOffloadPool()
1011
  self._register_states_for_offload()
1012
  self._offload_initialized = True
1013
  self._cpu_offload_pool.offload()
 
1015
  return loss
1016
 
1017
  # ------------------------------------------------------------------
1018
+ # CPU offload public helpers
1019
  # ------------------------------------------------------------------
1020
 
1021
+ def turn_on_cpu_offload(self):
1022
+ """Enable CPU offload for optimizer states."""
1023
+ if self.cpu_offload:
1024
+ return
1025
+ logger.info("[Muon] turn_on_cpu_offload")
1026
+ self.cpu_offload = True
1027
+ if not self.state:
1028
+ return
1029
+ self._cpu_offload_pool = CPUOffloadPool()
1030
+ self._offload_initialized = False
1031
+ self._register_states_for_offload()
1032
+ self._offload_initialized = True
1033
+ self._cpu_offload_pool.offload()
1034
+
1035
+ def turn_off_cpu_offload(self):
1036
+ """Disable CPU offload and keep optimizer states resident on GPU."""
1037
+ if not self.cpu_offload:
1038
+ return
1039
+ logger.info("[Muon] turn_off_cpu_offload")
1040
+ if self._offload_initialized:
1041
  self._cpu_offload_pool.reload()
1042
  torch.cuda.current_stream().synchronize()
1043
+ self._cpu_offload_pool = None
1044
+ self._offload_initialized = False
1045
+ self.cpu_offload = False
 
 
 
 
 
 
 
 
1046
 
1047
+ # ------------------------------------------------------------------
1048
+ # Checkpoint support for cpu_offload
1049
+ # ------------------------------------------------------------------
1050
 
1051
+ def state_dict(self) -> dict:
1052
+ if self.cpu_offload:
1053
+ raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
1054
+ return super().state_dict()
 
 
 
 
 
1055
 
1056
+ def load_state_dict(self, state_dict: dict) -> None:
1057
+ if self.cpu_offload:
1058
+ raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
1059
  super().load_state_dict(state_dict)
1060
 
1061
+ # Invalidate adamw.py's module-level tensor caches so that
1062
+ # the next step rebuilds them with the newly loaded state tensors.
1063
+ _placement_cache.clear()
1064
+ _tensor_cache.clear()