ZhenbinWang commited on
Commit
805d830
·
verified ·
1 Parent(s): 0bedf62

Upload 50 files

Browse files
Files changed (50) hide show
  1. flame/__init__.py +1 -0
  2. flame/__pycache__/__init__.cpython-310.pyc +0 -0
  3. flame/__pycache__/__init__.cpython-311.pyc +0 -0
  4. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  5. flame/__pycache__/config_manager.cpython-311.pyc +0 -0
  6. flame/__pycache__/config_manager.cpython-312.pyc +0 -0
  7. flame/__pycache__/data.cpython-311.pyc +0 -0
  8. flame/__pycache__/data.cpython-312.pyc +0 -0
  9. flame/__pycache__/train.cpython-310.pyc +0 -0
  10. flame/__pycache__/train.cpython-311.pyc +0 -0
  11. flame/__pycache__/train.cpython-312.pyc +0 -0
  12. flame/__pycache__/train_restart.cpython-311.pyc +0 -0
  13. flame/components/__init__.py +0 -0
  14. flame/components/__pycache__/__init__.cpython-311.pyc +0 -0
  15. flame/components/__pycache__/__init__.cpython-312.pyc +0 -0
  16. flame/components/__pycache__/checkpoint.cpython-311.pyc +0 -0
  17. flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  18. flame/components/checkpoint.py +59 -0
  19. flame/config_manager.py +960 -0
  20. flame/data.py +757 -0
  21. flame/models/__init__.py +0 -0
  22. flame/models/__pycache__/__init__.cpython-311.pyc +0 -0
  23. flame/models/__pycache__/__init__.cpython-312.pyc +0 -0
  24. flame/models/__pycache__/parallelize_fla.cpython-311.pyc +0 -0
  25. flame/models/__pycache__/parallelize_fla.cpython-312.pyc +0 -0
  26. flame/models/__pycache__/pipeline_fla.cpython-311.pyc +0 -0
  27. flame/models/__pycache__/pipeline_fla.cpython-312.pyc +0 -0
  28. flame/models/activation_offloading.py +447 -0
  29. flame/models/fla.toml +67 -0
  30. flame/models/parallelize_fla.py +550 -0
  31. flame/models/pipeline_fla.py +162 -0
  32. flame/tools/__init__.py +0 -0
  33. flame/tools/__pycache__/__init__.cpython-311.pyc +0 -0
  34. flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
  35. flame/tools/__pycache__/utils.cpython-311.pyc +0 -0
  36. flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
  37. flame/tools/utils.py +41 -0
  38. flame/train.py +637 -0
  39. flame/train2.py +625 -0
  40. flame/train_restart.py +694 -0
  41. flame/utils/__init__.py +0 -0
  42. flame/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  43. flame/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  44. flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  45. flame/utils/__pycache__/convert_dcp_to_hf.cpython-310.pyc +0 -0
  46. flame/utils/__pycache__/convert_dcp_to_hf.cpython-311.pyc +0 -0
  47. flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
  48. flame/utils/convert_dcp_to_hf.py +74 -0
  49. flame/utils/convert_hf_to_dcp.py +34 -0
  50. flame/utils/preprocess.py +122 -0
flame/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.0"
flame/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes). View file
 
flame/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (168 Bytes). View file
 
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (162 Bytes). View file
 
flame/__pycache__/config_manager.cpython-311.pyc ADDED
Binary file (40.6 kB). View file
 
flame/__pycache__/config_manager.cpython-312.pyc ADDED
Binary file (38.8 kB). View file
 
flame/__pycache__/data.cpython-311.pyc ADDED
Binary file (41.6 kB). View file
 
flame/__pycache__/data.cpython-312.pyc ADDED
Binary file (37.6 kB). View file
 
flame/__pycache__/train.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
flame/__pycache__/train.cpython-311.pyc ADDED
Binary file (25.9 kB). View file
 
flame/__pycache__/train.cpython-312.pyc ADDED
Binary file (26 kB). View file
 
flame/__pycache__/train_restart.cpython-311.pyc ADDED
Binary file (26 kB). View file
 
flame/components/__init__.py ADDED
File without changes
flame/components/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (159 Bytes). View file
 
flame/components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (147 Bytes). View file
 
flame/components/__pycache__/checkpoint.cpython-311.pyc ADDED
Binary file (3.63 kB). View file
 
flame/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (3.21 kB). View file
 
flame/components/checkpoint.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+ from datetime import timedelta
9
+ from io import BytesIO
10
+ from typing import Any, Dict, List
11
+
12
+ import torch
13
+ from torch.distributed.checkpoint.stateful import Stateful
14
+
15
+
16
+ @dataclass
17
+ class TrainState(Stateful):
18
+ step: int = 0
19
+ skipped_step: int = 0
20
+ token: int = 0
21
+ elapsed: timedelta = timedelta(0)
22
+ global_avg_losses: List[float] = field(default_factory=list)
23
+ global_max_losses: List[float] = field(default_factory=list)
24
+ log_steps: List[int] = field(default_factory=list)
25
+
26
+ def state_dict(self) -> Dict[str, Any]:
27
+ # Only checkpoint global_avg_losses and global_max_losses per log frequency
28
+ # to avoid sync overhead in every iteration.
29
+ global_avg_losses_bytes = BytesIO()
30
+ torch.save(self.global_avg_losses, global_avg_losses_bytes)
31
+ global_max_losses_bytes = BytesIO()
32
+ torch.save(self.global_max_losses, global_max_losses_bytes)
33
+ log_steps_bytes = BytesIO()
34
+ torch.save(self.log_steps, log_steps_bytes)
35
+ return {
36
+ "step": torch.tensor(self.step, dtype=torch.int32),
37
+ "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
38
+ "token": torch.tensor(self.token, dtype=torch.int64),
39
+ "elapsed": self.elapsed,
40
+ "global_avg_losses": global_avg_losses_bytes,
41
+ "global_max_losses": global_max_losses_bytes,
42
+ "log_steps": log_steps_bytes,
43
+ }
44
+
45
+ def load_state_dict(self, state_dict) -> None:
46
+ self.step = state_dict["step"].item()
47
+ self.skipped_step = state_dict.get("skipped_step", 0).item()
48
+ self.token = state_dict["token"].item()
49
+ self.elapsed = state_dict["elapsed"]
50
+ state_dict["global_avg_losses"].seek(0)
51
+ self.global_avg_losses = torch.load(
52
+ state_dict["global_avg_losses"], weights_only=False
53
+ )
54
+ state_dict["global_max_losses"].seek(0)
55
+ self.global_max_losses = torch.load(
56
+ state_dict["global_max_losses"], weights_only=False
57
+ )
58
+ state_dict["log_steps"].seek(0)
59
+ self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
flame/config_manager.py ADDED
@@ -0,0 +1,960 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import sys
9
+ from collections import defaultdict
10
+ from typing import Tuple
11
+
12
+ import torch
13
+
14
+ try:
15
+ import tomllib
16
+ except ModuleNotFoundError:
17
+ import tomli as tomllib
18
+
19
+ from torchtitan.tools.logging import logger
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
def string_list(raw_arg):
    """Parse a comma-separated string into a list of stripped, non-empty tokens."""
    stripped = (piece.strip() for piece in raw_arg.split(","))
    return [token for token in stripped if token]
31
+
32
+
33
def check_string_list_argument(args_dict: dict[str, dict], fullargname: str) -> None:
    """Normalize a raw comma-separated string value into a list, in place.

    Values coming from the toml file arrive as plain strings, while the
    argparse ``type=string_list`` path already yields lists; this makes the
    two sources consistent.

    Args:
        args_dict: Mapping of toml section name -> section options.
        fullargname: Dotted option name, e.g. ``"model.converters"``.
    """
    # BUGFIX: split on the first dot only, so option names that themselves
    # contain dots don't raise ValueError during unpacking. (Also replaced
    # the bogus `any` builtin in the annotation with a real type.)
    section, name = fullargname.split(".", 1)
    # Only convert values that are still raw strings.
    if (
        section in args_dict
        and name in args_dict[section]
        and isinstance(args_dict[section][name], str)
    ):
        sec = args_dict[section]
        sec[name] = string_list(sec[name])
43
+
44
+
45
+ class JobConfig:
46
+ """
47
+ A helper class to manage the train configuration.
48
+ Semantics:
49
+ - Default config is loaded from a toml file. If no toml file is provided,
50
+ then the default config is loaded from argparse defaults.
51
+ - if toml file has missing keys, they are filled with argparse defaults.
52
+ - if additional explicit cmd args are provided in addition to the toml
53
+ file, they will override the toml config and the argparse defaults
54
+
55
+ precedence order: cmdline > toml > argparse default
56
+
57
+ Arg parsing semantics:
58
+
59
+ Each argument starts with <prefix>_ which is the section name in the toml file
60
+ followed by name of the option in the toml file. For ex,
61
+ model.name translates to:
62
+ [model]
63
+ name
64
+ in the toml file
65
+ """
66
+
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memeory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.beta1", type=float, default=0.9,
188
+ help="Exponential moving average hyperparameters to use"
189
+ )
190
+ self.parser.add_argument(
191
+ "--optimizer.beta2", type=float, default=0.95,
192
+ help="Exponential moving average hyperparameters to use"
193
+ )
194
+ self.parser.add_argument(
195
+ "--optimizer.weight_decay", type=float, default=0.1,
196
+ help="Weight decay to use"
197
+ )
198
+ self.parser.add_argument(
199
+ "--optimizer.implementation",
200
+ type=str,
201
+ default="fused",
202
+ choices=["for-loop", "foreach", "fused"],
203
+ help="""
204
+ Specify which optimizer implementation to use:
205
+ - 'fused': Use fused implementation (CUDA only) for best performance.
206
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
207
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
208
+ - more info: https://pytorch.org/docs/stable/optim.html
209
+ """,
210
+ )
211
+ self.parser.add_argument(
212
+ "--optimizer.early_step_in_backward",
213
+ action="store_true",
214
+ help="""
215
+ Whether to apply optimizer in the backward. Caution, optimizer_in_backward
216
+ is not compatible with gradients clipping, users should not call
217
+ register_post_accumulate_grad_hook after the optimizer is built.""",
218
+ )
219
+
220
+ # lr scheduler configs
221
+ self.parser.add_argument(
222
+ "--lr_scheduler.warmup_steps",
223
+ type=int,
224
+ default=200,
225
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
226
+ )
227
+ self.parser.add_argument(
228
+ "--lr_scheduler.decay_ratio",
229
+ type=float,
230
+ default=None,
231
+ help="""
232
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
233
+
234
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
235
+ Otherwise, the learning rate will remain stable after the warmup period and
236
+ only start decaying during the last `decay_ratio` portion of the total training steps.
237
+
238
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.decay_type",
243
+ type=str,
244
+ default="linear",
245
+ choices=["linear", "sqrt", "cosine"],
246
+ help="""
247
+ Learning rate decay type to use during training:
248
+ - 'linear': linearly decays learning rate from initial to final value
249
+ - 'sqrt': decays learning rate following a 1 minus square root curve
250
+ - 'cosine': smoothly decays learning rate following a cosine curve
251
+ """,
252
+ )
253
+ self.parser.add_argument(
254
+ "--lr_scheduler.lr_min",
255
+ type=float,
256
+ default=0.0,
257
+ help="""
258
+ Min lr ratio for lr scheduler.
259
+
260
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
261
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
262
+ """,
263
+ )
264
+
265
+ # training configs
266
+ self.parser.add_argument(
267
+ "--training.batch_size", type=int, default=8, help="Batch size"
268
+ )
269
+ self.parser.add_argument(
270
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
271
+ )
272
+ self.parser.add_argument(
273
+ "--training.context_len",
274
+ type=int,
275
+ default=2048,
276
+ help="Max length allowed for each sequence",
277
+ )
278
+ self.parser.add_argument(
279
+ "--training.varlen",
280
+ action="store_true",
281
+ help="Whether to take sequences of variable length as input",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.gradient_accumulation_steps",
285
+ type=int,
286
+ default=1,
287
+ help="Number of steps to accumulate gradients before updating parameters",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.steps",
291
+ type=int,
292
+ default=10000,
293
+ help="How many train steps to run",
294
+ )
295
+ self.parser.add_argument(
296
+ "--training.max_norm",
297
+ type=float,
298
+ default=1.0,
299
+ help="Max norm for gradient clipping",
300
+ )
301
+ self.parser.add_argument(
302
+ "--training.skip_nan_inf",
303
+ action="store_true",
304
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
305
+ )
306
+ self.parser.add_argument(
307
+ "--training.dataset",
308
+ default="HuggingFaceFW/fineweb-edu",
309
+ help="Dataset to use, with comma separated values",
310
+ )
311
+ self.parser.add_argument(
312
+ "--training.dataset_name",
313
+ default=None,
314
+ help="The name of the dataset config, with comma separated values if provided",
315
+ )
316
+ self.parser.add_argument(
317
+ "--training.dataset_split",
318
+ default=None,
319
+ help="Dataset split to use, with comma separated values if provided",
320
+ )
321
+ self.parser.add_argument(
322
+ "--training.data_dir",
323
+ default=None,
324
+ help="Data dirs to use, with comma separated values if provided",
325
+ )
326
+ self.parser.add_argument(
327
+ "--training.data_files",
328
+ default=None,
329
+ help="Data files to use, with comma separated values if provided",
330
+ )
331
+ self.parser.add_argument(
332
+ "--training.data_probs",
333
+ default=None,
334
+ help="Data sampling probabilities, with comma separated values if provided",
335
+ )
336
+ self.parser.add_argument(
337
+ "--training.streaming",
338
+ action="store_true",
339
+ help="Whether to load dataset in streaming mode, used for huge dataset",
340
+ )
341
+ self.parser.add_argument(
342
+ "--training.num_workers",
343
+ type=int,
344
+ default=32,
345
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
346
+ )
347
+ self.parser.add_argument(
348
+ "--training.prefetch_factor",
349
+ type=int,
350
+ default=2,
351
+ help="Number of batches loaded in advance by each worker."
352
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
353
+ )
354
+ self.parser.add_argument(
355
+ "--training.data_parallel_replicate_degree",
356
+ type=int,
357
+ default=1,
358
+ help="""
359
+ The `data_parallel_replicate_degree` argument specifies the degree of
360
+ data parallelism for weight replication. When this value is greater
361
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
362
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
363
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
364
+ parallelism method used is DDP (Distributed Data Parallelism).
365
+ 1 means disabled.""",
366
+ )
367
+ self.parser.add_argument(
368
+ "--training.data_parallel_shard_degree",
369
+ type=int,
370
+ default=-1,
371
+ help="""
372
+ The `data_parallel_shard_degree` argument specifies the degree of data
373
+ parallelism for weight sharding. When this value is greater than 1, weights
374
+ will be sharded across `data_parallel_shard_degree` ranks. If
375
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
376
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
377
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
378
+
379
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
380
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.enable_cpu_offload",
384
+ action="store_true",
385
+ help="""
386
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
387
+ )
388
+ self.parser.add_argument(
389
+ "--training.tensor_parallel_degree",
390
+ type=int,
391
+ default=1,
392
+ help="Tensor Parallelism degree. 1 means disabled.",
393
+ )
394
+ self.parser.add_argument(
395
+ "--training.disable_loss_parallel",
396
+ action="store_true",
397
+ help="Whether to apply loss parallel when sequence parallel is enabled",
398
+ )
399
+ self.parser.add_argument(
400
+ "--training.fsdp_reshard_after_forward",
401
+ type=str,
402
+ default="default",
403
+ choices=["default", "always", "never"],
404
+ help="""
405
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
406
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
407
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
408
+ on `reshard_after_forward`.
409
+ The supported policies include "default", "always" and "never":
410
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
411
+ scenarios.
412
+ - "always" will enable `reshard_after_forward` for all forward passes.
413
+ - "never" will disable `reshard_after_forward` for all forward passes.
414
+ """,
415
+ )
416
+ self.parser.add_argument(
417
+ "--training.mixed_precision_param",
418
+ type=str,
419
+ default="bfloat16",
420
+ choices=["bfloat16", "float32"],
421
+ help="""
422
+ torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
423
+ This feature takes effect via fully_shard when data_parallel_shard_degree > 1 or
424
+ context_parallel_degree > 1; it takes effect via torch.autocast when data_replicate_degree >= 1
425
+ and no other parallelism is enabled, i.e. under DDP or single-device training.
426
+ """,
427
+ )
428
+ self.parser.add_argument(
429
+ "--training.mixed_precision_reduce",
430
+ type=str,
431
+ default="float32",
432
+ choices=["float32"],
433
+ help="""
434
+ torch dtype to use for reductions when applying mixed precision via FSDP.
435
+ This feature only takes effect when data_parallel_shard_degree > 1
436
+ """,
437
+ )
438
+ self.parser.add_argument(
439
+ "--training.compile",
440
+ action="store_true",
441
+ help="Whether to compile the model",
442
+ )
443
+ self.parser.add_argument(
444
+ "--training.gc_freq",
445
+ type=int,
446
+ default=50,
447
+ help="Python garbage control scheduling interval, in steps",
448
+ )
449
+ self.parser.add_argument(
450
+ "--training.seed",
451
+ type=int,
452
+ default=42,
453
+ help="Choose the base RNG seed used for training",
454
+ )
455
+ self.parser.add_argument(
456
+ "--training.deterministic",
457
+ action="store_true",
458
+ help="Use deterministic algorithms wherever possible, may be slower",
459
+ )
460
+ # metrics configs
461
+ self.parser.add_argument(
462
+ "--metrics.log_freq",
463
+ type=int,
464
+ default=10,
465
+ help="How often to log metrics to TensorBoard, in iterations",
466
+ )
467
+ self.parser.add_argument(
468
+ "--metrics.enable_tensorboard",
469
+ action="store_true",
470
+ help="Whether to log metrics to TensorBoard",
471
+ )
472
+ self.parser.add_argument(
473
+ "--metrics.disable_color_printing",
474
+ action="store_true",
475
+ help="Whether to disable color printing in logs",
476
+ )
477
+ self.parser.add_argument(
478
+ "--metrics.save_tb_folder",
479
+ type=str,
480
+ default="tb",
481
+ help="Folder to dump TensorBoard states",
482
+ )
483
+ self.parser.add_argument(
484
+ "--metrics.save_for_all_ranks",
485
+ action="store_true",
486
+ default=False,
487
+ help="""
488
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
489
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
490
+ component uses the 0th rank of the last stage pipeline group, which is the
491
+ only stage that computes loss metrics.
492
+ """,
493
+ )
494
+ self.parser.add_argument(
495
+ "--metrics.enable_wandb",
496
+ action="store_true",
497
+ help="Whether to log metrics to Weights & Biases",
498
+ )
499
+
500
+ self.parser.add_argument(
501
+ "--experimental.enable_async_tensor_parallel",
502
+ action="store_true",
503
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
504
+ )
505
+ self.parser.add_argument(
506
+ "--experimental.pipeline_parallel_degree",
507
+ type=int,
508
+ default=1,
509
+ help="""
510
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
511
+ If using looped schedules, this still specifies the number of physical ranks, not the number
512
+ of stages. Stages per rank are inferred from split points degree, and schedule.""",
513
+ )
514
+ self.parser.add_argument(
515
+ "--experimental.pipeline_parallel_split_points",
516
+ type=string_list,
517
+ nargs="+",
518
+ default=[],
519
+ help="""
520
+ Specify comma-separated names of modules to use as the beginning of a split point.
521
+
522
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
523
+ the first containing all the layers up to layers.0,
524
+ the second containing layers.0 and up to layers.2,
525
+ the third containing layers.2 and all the remaining layers.
526
+
527
+ Note: fully-automated splitting may be enabled in the future,
528
+ but currently the split points must be specified manually.""",
529
+ )
530
+ self.parser.add_argument(
531
+ "--experimental.pipeline_parallel_schedule",
532
+ type=str,
533
+ default="1F1B",
534
+ help="""
535
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
536
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
537
+ The schedule must be compatible with the split points and stages_per_rank.
538
+
539
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
540
+ and split_points = number of stages - 1
541
+ """,
542
+ )
543
+ self.parser.add_argument(
544
+ "--experimental.pipeline_parallel_schedule_csv",
545
+ type=str,
546
+ default="",
547
+ help="""
548
+ Specify the path to the pipeline parallel schedule csv file to use.
549
+ The pipeline_parallel_schedule argument must be either
550
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
551
+ """,
552
+ )
553
+
554
+ self.parser.add_argument(
555
+ "--experimental.pipeline_parallel_microbatches",
556
+ type=int,
557
+ default=None,
558
+ help="""
559
+ How many microbatches to split the global training batch into when using pipeline parallelism.
560
+
561
+ The global training batch size must be evenly divisible by the number of microbatches.
562
+
563
+ The default value will be the number of pipeline stages, if unspecified.
564
+ """,
565
+ )
566
+ self.parser.add_argument(
567
+ "--experimental.enable_compiled_autograd",
568
+ action="store_true",
569
+ help="Enable CompiledAutograd to compile the backward.",
570
+ )
571
+ self.parser.add_argument(
572
+ "--experimental.context_parallel_degree",
573
+ type=int,
574
+ default=1,
575
+ help="Context parallelism degree. 1 means disabled.",
576
+ )
577
+ self.parser.add_argument(
578
+ "--experimental.context_parallel_rotate_method",
579
+ type=str,
580
+ default="allgather",
581
+ help="""
582
+ The collective to use in context parallel SDPA for kv shards exchange.
583
+
584
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
585
+
586
+ 'alltoall' means to all-to-all shuffle the kv shards.
587
+
588
+ The default value is 'allgather'.
589
+ """,
590
+ )
591
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
592
+ # module and import TorchTitan training loop and execute it, which look cleaner.
593
+ # One reason to provide this option is to allow users to use the existing run script.
594
+ # While the script is pretty trivial now, we may add more logic when integrating
595
+ # with TorchFT.
596
+ # This option is subject to change and may be deleted in the future.
597
+ self.parser.add_argument(
598
+ "--experimental.custom_model_path",
599
+ type=str,
600
+ default="",
601
+ help="""
602
+ The --custom_model_path option allows to specify a custom path to a model module
603
+ that is not natively implemented within TorchTitan.
604
+ Acceptable values are the file system path to the module (e.g., my_models/model_x)
605
+ dotted import module (e.g., some_package.model_x).
606
+ """,
607
+ )
608
+ # checkpointing configs
609
+ self.parser.add_argument(
610
+ "--checkpoint.enable_checkpoint",
611
+ action="store_true",
612
+ help="Whether to enable checkpoint",
613
+ )
614
+ self.parser.add_argument(
615
+ "--checkpoint.folder",
616
+ type=str,
617
+ default="checkpoint",
618
+ help="""
619
+ The folder to store the checkpoints.
620
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
621
+ """,
622
+ )
623
+ self.parser.add_argument(
624
+ "--checkpoint.initial_load_path", type=str, default=None,
625
+ help="""
626
+ This option specifies the path to the initial checkpoint to load, which is
627
+ particularly useful for resuming training from a previous run with a
628
+ different output path or when loading a checkpoint from a pre-trained model.
629
+ If the checkpoint folder for the current run is not empty,
630
+ located at {--job.dump_folder}/{--checkpoint.folder}, this option will be ignored.
631
+ This feature allows users to load an initial checkpoint from a different folder and
632
+ continue training, saving new checkpoints to the specified folder without affecting
633
+ the existing ones.
634
+
635
+ Note that the path should contain the full path to the checkpoint folder,
636
+ including the step number, if any; for example,
637
+ "//pre_train/checkpoints/llama3/llama3_8b/step_10000".
638
+ """
639
+ )
640
+ self.parser.add_argument(
641
+ "--checkpoint.initial_load_model_weights_only",
642
+ dest='checkpoint.initial_load_model_weights_only', action="store_true", default=True,
643
+ help="""
644
+ This option specifies if only the model weights should be loaded during the initial
645
+ checkpoint load. The option is only used when `initial_load_path` is specified, and
646
+ only applies to a model_weights_only checkpoint. Loading a periodic checkpoint
647
+ may lead to unexpected behavior if this option is set to True.
648
+ If False, the checkpoint at `initial_load_path` is treated as a standard training
649
+ checkpoint, including optimizer and training states.
650
+ The default setting for this option is True. Note that you will have to use
651
+ `--checkpoint.no_initial_load_model_weights_only` to override the default setting.
652
+ """
653
+ )
654
+ self.parser.add_argument(
655
+ "--checkpoint.no_initial_load_model_weights_only",
656
+ dest='checkpoint.initial_load_model_weights_only', action="store_false",
657
+ )
658
+ self.parser.add_argument(
659
+ "--checkpoint.interval",
660
+ type=int,
661
+ default=500,
662
+ help="Checkpointing interval in steps.",
663
+ )
664
+ self.parser.add_argument(
665
+ "--checkpoint.last_save_model_weights_only",
666
+ action="store_true",
667
+ help="""
668
+ When last_save_model_weights_only=True, only model weights will be saved at the end of training,
669
+ the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
670
+ after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
671
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
672
+ The default value is false.
673
+ """,
674
+ )
675
+ self.parser.add_argument(
676
+ "--checkpoint.export_dtype",
677
+ type=str,
678
+ default="float32",
679
+ choices=["float16", "bfloat16", "float32"],
680
+ help="""
681
+ Converts to the specified precision when training completes and model_weights_only=true.
682
+ Currently supports float32, float16, and bfloat16.
683
+ The default value is float32.
684
+ """,
685
+ )
686
+ self.parser.add_argument(
687
+ "--checkpoint.create_seed_checkpoint",
688
+ action="store_true",
689
+ help="""
690
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
691
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
692
+ Could be implemented as a separate script, but this way shares more code.
693
+ """,
694
+ )
695
+ self.parser.add_argument(
696
+ "--checkpoint.async_mode",
697
+ type=str,
698
+ default="disabled",
699
+ help="""
700
+ Which async checkpoint mode to use. Currently there are 3 different modes.
701
+ 1. "disabled": synchronized checkpointing will be used.
702
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
703
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
704
+ space and creates a separate process for faster GPU->CPU transfer
705
+ performance and eliminating GIL contention. The cost is increased CPU
706
+ memory usage. If insufficient CPU memory is available, performance may
707
+ degrade due to memory paging. For most users, "async" should suffice as
708
+ the performance overhead is typically small (on the order of tens of
709
+ seconds) compared to checkpointing frequency. This mode can be employed
710
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
711
+ appropriate hardware support such as ample CPU memory and fast PCIe.
712
+
713
+ "disabled" is the default mode.
714
+ """,
715
+ )
716
+ self.parser.add_argument(
717
+ "--checkpoint.keep_latest_k",
718
+ type=int,
719
+ default=0,
720
+ help="""
721
+ Keeps only the latest k checkpoints, and purging older ones. If 0, keep all checkpoints.
722
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
723
+ saved. As a result, the metadata of the last one may not be ready yet.
724
+ """,
725
+ )
726
+ self.parser.add_argument(
727
+ "--checkpoint.load_step",
728
+ type=int,
729
+ default=-1,
730
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
731
+ )
732
+ self.parser.add_argument(
733
+ "--checkpoint.exclude_from_loading",
734
+ type=string_list,
735
+ nargs="*",
736
+ default=[],
737
+ help="""
738
+ Exclude specific keys from being loaded from the checkpoint.
739
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
740
+ This will load the model only, excluding the specified keys.
741
+ """,
742
+ )
743
+ # activation checkpointing configs
744
+ self.parser.add_argument(
745
+ "--activation_checkpoint.mode",
746
+ type=str,
747
+ default="selective",
748
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
749
+ )
750
+ self.parser.add_argument(
751
+ "--activation_checkpoint.selective_ac_option",
752
+ type=str,
753
+ default="2", # 2 = checkpoint every other layer
754
+ help="""
755
+ Selective activation checkpointing options ['int', 'op'].
756
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
757
+ """,
758
+ )
759
+
760
+ self.parser.add_argument(
761
+ "--activation_offload.mode",
762
+ type=str,
763
+ default="none",
764
+ help="""
765
+ if we are using activation offload or not. Options are ['none', 'full'].
766
+ """,
767
+ )
768
+
769
+ # float8 configs
770
+ self.parser.add_argument(
771
+ "--float8.enable_fsdp_float8_all_gather",
772
+ action="store_true",
773
+ help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
774
+ )
775
+ self.parser.add_argument(
776
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
777
+ action="store_true",
778
+ help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
779
+ )
780
+ self.parser.add_argument(
781
+ "--float8.force_recompute_fp8_weight_in_bwd",
782
+ action="store_true",
783
+ help="""
784
+ Whether to force the recomputation of FP8 weights during backward pass.
785
+ When using FSDP with tensorwise scaling, it is recommended to enable
786
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
787
+ for backward computation.
788
+ """,
789
+ )
790
+ self.parser.add_argument(
791
+ "--float8.recipe_name",
792
+ type=str,
793
+ default=None,
794
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
795
+ help="""
796
+ If specified, creates float8 config from recipe name, valid choices are
797
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
798
+ """,
799
+ )
800
+
801
+ # communications library settings
802
+ self.parser.add_argument(
803
+ "--comm.init_timeout_seconds",
804
+ type=int,
805
+ default=300,
806
+ help="Timeout for communication operations, during initialization and first train step.",
807
+ )
808
+ self.parser.add_argument(
809
+ "--comm.train_timeout_seconds",
810
+ type=int,
811
+ default=100,
812
+ help=(
813
+ "Timeout for communication operations after the first train step -- "
814
+ "usually a tighter bound than during initialization."
815
+ ),
816
+ )
817
+ self.parser.add_argument(
818
+ "--comm.trace_buf_size",
819
+ type=int,
820
+ default=20000,
821
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
822
+ )
823
+
824
+ # memory estimation settings
825
+ self.parser.add_argument(
826
+ "--memory_estimation.enabled",
827
+ help="Whether to estimate memory usage for FSDP",
828
+ action="store_true",
829
+ )
830
+
831
+ self.parser.add_argument(
832
+ "--memory_estimation.disable_fake_mode",
833
+ help="Whether to estimate memory under FakeTensorMode",
834
+ action="store_true",
835
+ )
836
+
837
+ self.parser.add_argument(
838
+ "--fault_tolerance.enable",
839
+ action="store_true",
840
+ help="""
841
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
842
+ And --fault_tolerance.data_parallel_replicate_degree should be 1 and
843
+ --fault_tolerance.group_size will be used to control the maximum
844
+ replicate group size as the replicate group size is dynamic.
845
+
846
+ Note that this is still an experimental feature.
847
+ """,
848
+ )
849
+
850
+ self.parser.add_argument(
851
+ "--fault_tolerance.replica_id",
852
+ type=int,
853
+ default=0,
854
+ help="The TorchFT replica ID of this run.",
855
+ )
856
+
857
+ self.parser.add_argument(
858
+ "--fault_tolerance.group_size",
859
+ type=int,
860
+ default=0,
861
+ help="""
862
+ The number of TorchFT replicate groups. This number will be used for
863
+ dataloader to split the dataset across the replicate groups and FSDP
864
+ dimension
865
+ """,
866
+ )
867
+
868
+ self.parser.add_argument(
869
+ "--fault_tolerance.min_replica_size",
870
+ type=int,
871
+ default=1,
872
+ help="The minimum number of FT replica for each step.",
873
+ )
874
+
875
+ def to_dict(self):
876
+ return self.args_dict
877
+
878
+ def parse_args(self, args_list: list = sys.argv[1:]):
879
+ args, cmd_args = self.parse_args_from_command_line(args_list)
880
+ config_file = getattr(args, "job.config_file", None)
881
+ # build up a two level dict
882
+ args_dict = self._args_to_two_level_dict(args)
883
+ if config_file is not None:
884
+ try:
885
+ with open(config_file, "rb") as f:
886
+ for k, v in tomllib.load(f).items():
887
+ # to prevent overwrite of non-specified keys
888
+ args_dict[k] |= v
889
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
890
+ logger.exception(
891
+ f"Error while loading the configuration file: {config_file}"
892
+ )
893
+ logger.exception(f"Error details: {str(e)}")
894
+ raise e
895
+
896
+ # Checking string-list arguments are properly split into a list
897
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
898
+ string_list_argnames = self._get_string_list_argument_names()
899
+ for n in string_list_argnames:
900
+ check_string_list_argument(args_dict, n)
901
+
902
+ # override args dict with cmd_args
903
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
904
+ for section, section_args in cmd_args_dict.items():
905
+ for k, v in section_args.items():
906
+ args_dict[section][k] = v
907
+
908
+ self.args_dict = args_dict
909
+
910
+ for k, v in args_dict.items():
911
+ class_type = type(k.title(), (), v)
912
+ setattr(self, k, class_type())
913
+ self._validate_config()
914
+
915
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
916
+ args_dict = defaultdict(defaultdict)
917
+ for k, v in vars(args).items():
918
+ first_level_key, second_level_key = k.split(".", 1)
919
+ args_dict[first_level_key][second_level_key] = v
920
+ return args_dict
921
+
922
    def _validate_config(self) -> None:
        """Check that mandatory model settings were provided.

        NOTE(review): these checks use ``assert`` and therefore vanish under
        ``python -O``; raising an explicit error would be more robust.
        """
        # TODO: Add more mandatory validations
        assert self.model.config
        assert self.model.tokenizer_path
926
+
927
+ def _get_string_list_argument_names(self) -> list[str]:
928
+ """Get the parser argument names of type `string_list`."""
929
+ string_list_args = [
930
+ v.dest for v in self.parser._actions if v.type is string_list
931
+ ]
932
+ return string_list_args
933
+
934
+ def parse_args_from_command_line(
935
+ self, args_list
936
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
937
+ """
938
+ Parse command line arguments and return the parsed args and the command line only args
939
+ """
940
+ args = self.parser.parse_args(args_list)
941
+ string_list_argnames = set(self._get_string_list_argument_names())
942
+
943
+ # aux parser to parse the command line only args, with no defaults from main parser
944
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
945
+ for arg, val in vars(args).items():
946
+ if isinstance(val, bool):
947
+ aux_parser.add_argument(
948
+ "--" + arg, action="store_true" if val else "store_false"
949
+ )
950
+ elif arg in string_list_argnames:
951
+ # without this special case, type inference breaks here,
952
+ # since the inferred type is just 'list' and it ends up flattening
953
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
954
+ aux_parser.add_argument("--" + arg, type=string_list)
955
+ else:
956
+ aux_parser.add_argument("--" + arg, type=type(val))
957
+
958
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
959
+
960
+ return args, cmd_args
flame/data.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import pickle
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
10
+
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ from datasets import Dataset, IterableDataset, interleave_datasets, load_dataset
15
+ from datasets.iterable_dataset import ShufflingConfig
16
+ from torch.distributed.checkpoint.stateful import Stateful
17
+ from torchdata.stateful_dataloader import StatefulDataLoader
18
+ from transformers import PreTrainedTokenizer
19
+
20
+ from torchtitan.tools import utils
21
+ from torchtitan.tools.logging import logger
22
+
23
+ datasets.logging.set_verbosity_info()
24
+
25
class BufferShuffledIterableDataset(IterableDataset):
    """Tokenize text on the fly and yield fixed-length chunks in a
    randomized order drawn from an in-memory shuffle buffer.

    Tokenized documents are concatenated into a token stream, packed into a
    buffer of ``buffer_size`` chunks of ``seq_len`` tokens, and yielded by
    picking random buffer slots, each vacated slot being refilled with fresh
    tokens. All iterator and RNG state is checkpointable through
    ``state_dict``/``load_state_dict``.

    Args:
        dataset: source dataset whose samples carry a ``'text'`` field.
        tokenizer: tokenizer used for on-the-fly encoding.
        seq_len: length of each yielded ``input_ids`` chunk.
        rank: data-parallel rank of this worker.
        world_size: total number of data-parallel workers.
        buffer_size: number of chunks held in the shuffle buffer.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        seq_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024,
    ) -> None:
        # NOTE: the original annotated ``-> BufferShuffledIterableDataset``
        # on __init__, which returns None.
        self.dataset = dataset
        self.tokenizer = tokenizer

        # Each rank iterates only its own shard of the source dataset.
        self.data = dataset.shard(world_size, rank)
        self.seq_len = seq_len

        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # Use the narrowest unsigned dtype that can hold any token id, to
        # keep the shuffle buffer's memory footprint small.
        if tokenizer.vocab_size < torch.iinfo(torch.uint16).max:
            self.dtype = torch.uint16
        elif tokenizer.vocab_size < torch.iinfo(torch.uint32).max:
            self.dtype = torch.uint32
        else:
            self.dtype = torch.uint64
        self.states = None  # checkpointed state of the sharded source dataset
        self.buffer = torch.tensor([], dtype=self.dtype)  # chunk shuffle buffer
        self.tokens = []  # overflow tokens not yet packed into the buffer
        self.rand_id = 0  # consumed entries of the current random-index batch
        self.token_id = 0  # consumed tokens of the current stream in sample()
        self.rng_state = None  # saved torch.Generator state for resumption
        self._epoch = 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # max number of tokens allowed in the chunk buffer
        n_tokens = self.buffer_size * self.seq_len

        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token buffer
                self.tokens += sample
                # if the token buffer is full, start sampling
                # NOTE: we first convert the token ids to a tensor of shape
                # [n_chunks, seq_len] for efficiency
                if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
                    self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[n_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            n_chunks = len(self.tokens) // self.seq_len
            # handle the left tokens in the buffer
            if n_chunks > 0:
                n_tokens = n_chunks * self.seq_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
                self.tokens = self.tokens[n_tokens:]
                for i in indices:
                    yield {'input_ids': self.buffer[i]}

    def tokenize(self, data, batch_size: int = 64):
        """Tokenize samples from ``data`` in batches of ``batch_size``.

        Yields one token-id list per input sample. ``self.states`` is set to
        the source-dataset state recorded right after the yielded sample was
        read, so training can resume without repeating or skipping data.
        """
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                texts, states = [], []
        # flush the final partial batch
        if len(texts) > 0:
            for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def sample(self, indices):
        """Yield chunks at random buffer slots, refilling each vacated slot
        with the next ``seq_len`` tokens from ``self.tokens``."""
        n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.seq_len
            self.token_id += self.seq_len
            yield {'input_ids': self.buffer[i].to(torch.long)}
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        buffer_size: int = 1024,
        g: Optional[torch.Generator] = None,
    ) -> Iterable[int]:
        """Yield an endless stream of random ints in ``[low, high)``.

        Indices are drawn in batches of ``buffer_size`` for efficiency, and
        ``self.rand_id``/``self.rng_state`` track progress so the stream can
        be resumed exactly after a checkpoint.
        """
        # BUGFIX: the original signature used
        # ``g: torch.Generator = torch.Generator()`` — a mutable default
        # created once at definition time and silently shared by every call
        # that omits ``g``. Create a fresh generator per call instead.
        if g is None:
            g = torch.Generator()
        indices = torch.empty(buffer_size, dtype=torch.long)
        while True:
            # record the generator states before sampling
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        """Set the epoch (reseeds shuffling; forwarded to the source dataset
        when it supports ``set_epoch``)."""
        self._epoch = epoch
        if hasattr(self.dataset, 'set_epoch'):
            self.dataset.set_epoch(epoch)

    def state_dict(self):
        """Snapshot everything needed to resume iteration exactly."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by :meth:`state_dict`."""
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
155
+
156
+
157
class OnlineTokenizedIterableDataset(IterableDataset):
    """Tokenize raw text on the fly and emit contiguous ``seq_len``-token
    chunks in document order (no shuffling).

    Supports checkpoint/resume of both the source-dataset position and the
    partial token buffer via ``state_dict``/``load_state_dict``.
    """

    def __init__(
        self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
    ) -> None:
        # NOTE: annotation fixed — __init__ returns None, not the class.
        self.dataset = dataset
        self.tokenizer = tokenizer

        # Each rank iterates only its own shard of the source dataset.
        self.data = dataset.shard(world_size, rank)
        self.seq_len = seq_len
        self.rank = rank
        self.world_size = world_size

        self.states = None  # checkpointed state of the sharded source dataset
        self.tokens = []  # leftover token ids carried between chunks

    def __iter__(self):
        # Resume the underlying dataset iterator from checkpointed state.
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # Loop forever: restart the (sharded) source once it is exhausted.
        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token buffer
                self.tokens += sample

                # Emit as many full seq_len chunks as the buffer allows.
                while len(self.tokens) >= self.seq_len:
                    input_ids = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
                    self.tokens = self.tokens[self.seq_len:]
                    yield {'input_ids': input_ids}

    def tokenize(self, data, buffer_size: int = 64):
        """Tokenize samples from ``data`` in batches of ``buffer_size``.

        Accepts either a ``'text'`` or a ``'content'`` field per sample and
        raises ValueError otherwise. ``self.states`` is set to the source
        state recorded right after the yielded sample was read, so training
        can resume without repeating data.
        """
        buffer, states = [], []
        for sample in data:
            if sample.get('text', None) is not None:
                buffer.append(sample['text'])
            elif sample.get('content', None) is not None:
                buffer.append(sample['content'])
            else:
                raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")
            states.append(self.data.state_dict())
            if len(buffer) == buffer_size:
                for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                buffer, states = [], []
        # flush the final partial batch
        if len(buffer) > 0:
            for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def state_dict(self):
        """Snapshot the source-dataset state and the partial token buffer."""
        return {'states': self.states, 'tokens': deepcopy(self.tokens)}

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by :meth:`state_dict`."""
        self.states = state_dict['states']
        self.tokens = deepcopy(state_dict['tokens'])
212
+
213
+
214
class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
    """A checkpointable variant of `datasets`' buffer-shuffle iterable.

    Extends the upstream class with a richer ``_state_dict`` (the in-memory
    buffer contents plus the numpy bit-generator state and index offsets) so
    iteration can be resumed exactly after a save/load cycle.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _init_state_dict(self) -> dict:
        self._state_dict = self.ex_iterable._init_state_dict()
        # The buffer list is wrapped in a 1-tuple so the state dict carries a
        # reference to the same mutable list that __iter__ fills in place.
        self._state_dict['mem_buffer'] = ([],)
        self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
        self._state_dict['bit_generator_index_offset'] = 0
        self._state_dict['bit_generator_index_offset_shuffle'] = 0
        return self._state_dict

    def __iter__(self):
        buffer_size = self.buffer_size
        rng = deepcopy(self.generator)
        # this is the shuffle buffer that we keep in memory
        mem_buffer = self._state_dict['mem_buffer'][0]
        # this is an infinite iterator that randomly samples the index of the source to pick examples from
        index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
        if self._state_dict:
            rng.bit_generator.state = self._state_dict['bit_generator_state']
        indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
        # skip already consumed ones
        for _ in range(index_offset):
            i = next(indices_iterator)

        for x in self.ex_iterable:
            if len(mem_buffer) < buffer_size:  # if the buffer is not full, keep filling the buffer
                mem_buffer.append(x)
            else:  # otherwise, pick an example from it
                i = next(indices_iterator)
                index_offset = (index_offset + 1) % buffer_size
                if self._state_dict:
                    self._state_dict['bit_generator_index_offset'] = index_offset
                    # A full batch of indices has been consumed: checkpoint the
                    # RNG state so resumption re-draws the same next batch.
                    if index_offset == 0:
                        self._state_dict['bit_generator_state'] = rng.bit_generator.state
                selected = mem_buffer[i]
                mem_buffer[i] = x  # replace the picked example by a new one
                yield selected

        index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
        if self._state_dict:
            rng.bit_generator.state = self._state_dict['bit_generator_state']

        # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
        for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
            index_offset = index_offset + 1
            if self._state_dict:
                self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
            yield mem_buffer[i]

    def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
        """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
        )

    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
        """Keep only the requested shard."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
            buffer_size=self.buffer_size,
            generator=self.generator,
        )

    def load_state_dict(self, state_dict: dict) -> dict:
        """Merge ``state_dict`` into ``self._state_dict`` recursively in place.

        Dicts and lists are updated element-wise so that object identity of
        the existing containers (notably the shared ``mem_buffer`` list) is
        preserved; scalar leaves are replaced.
        """
        def _inner_load_state_dict(state, new_state):
            if new_state is not None and isinstance(state, dict):
                for key in new_state:
                    state[key] = _inner_load_state_dict(state[key], new_state[key])
                return state
            elif new_state is not None and isinstance(state, list):
                for i in range(len(state)):
                    state[i] = _inner_load_state_dict(state[i], new_state[i])
                return state
            return new_state

        return _inner_load_state_dict(self._state_dict, state_dict)
292
+
293
+
294
def shuffle(
    dataset: IterableDataset,
    seed: int = 42,
    generator: Optional[np.random.Generator] = None,
    buffer_size: int = 1024,
):
    """Wrap ``dataset`` with a resumable buffer-shuffle.

    Rebuilds a new ``IterableDataset`` whose examples iterable is the local
    :class:`BufferShuffledExamplesIterable` (checkpointable), while copying
    over the original dataset's info/split/formatting/distributed settings.

    Args:
        dataset: the iterable dataset to shuffle.
        seed: RNG seed used when ``generator`` is not supplied.
        generator: optional numpy generator; deep-copied so the caller's
            generator is left untouched.
        buffer_size: number of examples held in the shuffle buffer.
    """
    generator = np.random.default_rng(seed) if generator is None else deepcopy(generator)
    # NOTE(review): this reaches into `datasets`' private attributes
    # (_ex_iterable, _info, ...); it mirrors IterableDataset.shuffle and may
    # break across `datasets` versions.
    return IterableDataset(
        ex_iterable=BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=generator),
        info=dataset._info.copy(),
        split=dataset._split,
        formatting=dataset._formatting,
        shuffling=ShufflingConfig(generator=generator, _original_seed=seed),
        distributed=copy.deepcopy(dataset._distributed),
        token_per_repo_id=dataset._token_per_repo_id,
    )
310
+
311
+
312
+ @dataclass
313
+ class DataCollatorForLanguageModeling:
314
+ """
315
+ Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
316
+ If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.
317
+
318
+ Args:
319
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
320
+ The tokenizer used for encoding the data.
321
+ context_len (`int`, optional):
322
+ When `varlen=True`, sequences longer than this length within a document
323
+ (as determined by `cu_seqlens`) will be further chunked.
324
+ varlen (`bool`):
325
+ Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).
326
+
327
+ Returns:
328
+ A dictionary with the following keys:
329
+ - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
330
+ - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
331
+ - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
332
+ - `cu_seqlens`: Tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[1, num_sequences + 1]`.
333
+
334
+ NOTE: When `varlen=True`, the `batch_size` must be 1.
335
+ """
336
+
337
+ tokenizer: PreTrainedTokenizer
338
+ context_len: Optional[int] = None
339
+ varlen: bool = False
340
+
341
+ def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
342
+ if not isinstance(examples[0], Dict):
343
+ examples = [{'input_ids': example} for example in examples]
344
+
345
+ def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
346
+ tensorized = {}
347
+ for key in ['input_ids', 'cu_seqlens']:
348
+ if key not in example:
349
+ continue
350
+ if isinstance(example[key], List):
351
+ tensorized[key] = torch.tensor(example[key], dtype=torch.long)
352
+ elif isinstance(example[key], np.ndarray):
353
+ tensorized[key] = torch.from_numpy(example[key])
354
+ else:
355
+ tensorized[key] = example[key]
356
+ return tensorized
357
+
358
+ examples = list(map(tensorize, examples))
359
+
360
+ if not self.varlen:
361
+ # --- Handling for varlen=False (Batch Padding) ---
362
+ length_of_first = examples[0]['input_ids'].size(0)
363
+ needs_padding = not all(example['input_ids'].size(0) == length_of_first for example in examples)
364
+
365
+ if needs_padding:
366
+ # Check for pad token if padding is actually required
367
+ if self.tokenizer.pad_token_id is None:
368
+ raise ValueError(
369
+ f'You are attempting to pad samples but the tokenizer you are using '
370
+ f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
371
+ )
372
+ # Pad using the tokenizer, ensuring attention_mask is returned
373
+ batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)
374
+ else:
375
+ # No padding needed, stack directly and create a full attention mask
376
+ input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
377
+ batch = {
378
+ 'input_ids': input_ids,
379
+ # Create attention mask of all ones
380
+ 'attention_mask': torch.ones_like(input_ids),
381
+ }
382
+
383
+ # Create labels by cloning input_ids
384
+ labels = batch['input_ids'].clone()
385
+ # Mask labels only where attention_mask is 0 (padding positions)
386
+ if 'attention_mask' in batch:
387
+ labels[batch['attention_mask'] == 0] = -100
388
+ batch['labels'] = labels
389
+
390
+ else:
391
+ # --- Handling for varlen=True (Concatenated Sequences) ---
392
+ if len(examples) > 1:
393
+ raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')
394
+
395
+ batch = {'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)}
396
+
397
+ # --- cu_seqlens calculation logic remains the same ---
398
+ if 'cu_seqlens' in examples[0]:
399
+ batch['cu_seqlens'] = (
400
+ torch.cat([example['cu_seqlens'] for example in examples], dim=0).unsqueeze(0).to(dtype=torch.int32)
401
+ ) # Ensure int32
402
+ else:
403
+ # determine boundaries by bos/eos positions
404
+ # Check for bos_token_id first
405
+ if self.tokenizer.bos_token_id is not None:
406
+ cu_seqlens = []
407
+ # Handle case where the sequence doesn't start with BOS
408
+ if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
409
+ cu_seqlens.append(torch.tensor([0], device=batch['input_ids'].device)) # Match device
410
+ # Find all BOS token positions
411
+ bos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1]
412
+ # Ensure bos_positions is on the correct device if empty
413
+ if bos_positions.numel() == 0 and len(cu_seqlens) > 0:
414
+ cu_seqlens.append(bos_positions.to(cu_seqlens[0].device))
415
+ elif bos_positions.numel() > 0:
416
+ cu_seqlens.append(bos_positions)
417
+ # Add the end of the entire batch
418
+ cu_seqlens.append(
419
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
420
+ ) # Match device and use size(1)
421
+ # Filter out empty tensors before cat
422
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
423
+ if not cu_seqlens: # Handle case where input is empty or has no BOS
424
+ batch['cu_seqlens'] = torch.tensor(
425
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
426
+ )
427
+ else:
428
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
429
+
430
+ # Else, check for eos_token_id
431
+ elif self.tokenizer.eos_token_id is not None:
432
+ cu_seqlens = [torch.tensor([0], device=batch['input_ids'].device)] # Match device
433
+ # Find positions *after* EOS tokens
434
+ eos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1
435
+ # Ensure eos_positions is on the correct device if empty
436
+ if eos_positions.numel() > 0:
437
+ cu_seqlens.append(eos_positions)
438
+ # Handle case where the sequence doesn't end with EOS
439
+ if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
440
+ # Only add the final length if the last found EOS wasn't already the end
441
+ if eos_positions.numel() == 0 or eos_positions[-1] != batch['input_ids'].size(1):
442
+ cu_seqlens.append(
443
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
444
+ ) # Match device and use size(1)
445
+ # Filter out empty tensors before cat
446
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
447
+ if not cu_seqlens: # Handle case where input is empty or has no EOS
448
+ batch['cu_seqlens'] = torch.tensor(
449
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
450
+ )
451
+ else:
452
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
453
+ # Else, neither BOS nor EOS is usable
454
+ else:
455
+ raise ValueError(
456
+ 'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
457
+ 'or an eos_token_id defined to act as sequence separators.'
458
+ )
459
+
460
+ # --- cu_seqlens validation checks remain the same ---
461
+ if batch['cu_seqlens'].numel() < 2:
462
+ raise ValueError(f'Calculated cu_seqlens must have at least start and end: {batch["cu_seqlens"]}')
463
+ if not torch.all(batch['cu_seqlens'][1:] >= batch['cu_seqlens'][:-1]):
464
+ raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {batch["cu_seqlens"]}')
465
+ if batch['cu_seqlens'][0] != 0:
466
+ raise ValueError(f'Calculated cu_seqlens do not start at 0: {batch["cu_seqlens"]}')
467
+ if batch['cu_seqlens'][-1] != batch['input_ids'].size(1):
468
+ # Allow empty sequence case where cu_seqlens=[0, 0] and input_ids.size(1)=0
469
+ if not (batch['cu_seqlens'].tolist() == [0, 0] and batch['input_ids'].size(1) == 0):
470
+ raise ValueError(
471
+ f'Calculated cu_seqlens do not end at total length {batch["input_ids"].size(1)}: '
472
+ f'{batch["cu_seqlens"]}'
473
+ )
474
+
475
+ # --- context_len splitting logic remains the same ---
476
+ if self.context_len is not None:
477
+ # This logic splits sequences based on context_len *after* initial boundaries are found
478
+ bos = batch['cu_seqlens'][:-1].tolist()
479
+ eos = batch['cu_seqlens'][1:].tolist()
480
+ # Handle empty sequences between boundaries
481
+ split_boundaries = []
482
+ for i, j in zip(bos, eos):
483
+ if i < j: # Only process non-empty sequences
484
+ split_boundaries.append(torch.arange(i, j, self.context_len, device=batch['input_ids'].device))
485
+ # Add the final end point if it wasn't included by arange
486
+ final_end_point = torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
487
+ # Concatenate all boundaries
488
+ if not split_boundaries: # Handle case of completely empty input
489
+ batch['cu_seqlens'] = torch.tensor([0, 0], dtype=torch.int32, device=batch['input_ids'].device)
490
+ else:
491
+ batch['cu_seqlens'] = torch.cat(split_boundaries + [final_end_point]).to(dtype=torch.int32)
492
+ # Ensure uniqueness and sort, as arange might duplicate the endpoint
493
+ batch['cu_seqlens'] = torch.unique(batch['cu_seqlens'])
494
+
495
+ # Create labels directly from input_ids, NO padding mask needed for varlen
496
+ labels = batch['input_ids'].clone()
497
+ batch['labels'] = labels
498
+
499
+ return batch
500
+
501
+
502
class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
    """StatefulDataLoader whose checkpoint state is keyed by data-parallel rank.

    The dataloader state is saved under a single per-rank key so that each DP
    rank persists exactly one copy of its own state, instead of replicating the
    same state across the other parallel dimensions.
    """

    def __init__(
        self,
        rank: int,
        dataset: IterableDataset,
        batch_size: int,
        collate_fn: Callable,
        num_workers: int = 0,
        pin_memory: bool = False,
        prefetch_factor: int = 2,
        persistent_workers: bool = False,
        snapshot_every_n_steps: Optional[int] = 1,
    ):
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=num_workers,
            pin_memory=pin_memory,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            snapshot_every_n_steps=snapshot_every_n_steps,
        )
        self.rank = rank

    def _state_key(self) -> str:
        # Key under which this rank's dataloader state is stored.
        return f'rank_{self.rank}'

    def state_dict(self) -> Dict[str, Any]:
        # Serialize the parent state under the rank-specific key so the same
        # state is not duplicated across other parallel dimensions.
        return {self._state_key(): pickle.dumps(super().state_dict())}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # An empty state dict is a valid "nothing to restore" signal.
        if not state_dict:
            return

        key = self._state_key()
        if key not in state_dict:
            logger.warning(f'DataLoader state is empty for dp rank {self.rank}, expected key rank_{self.rank}')
            return
        super().load_state_dict(pickle.loads(state_dict[key]))
544
+
545
+
546
def build_dataset(
    dataset: str,
    dataset_name: Optional[str] = None,
    dataset_split: str = 'train',
    data_dir: Optional[str] = None,
    data_files: Optional[str] = None,
    data_probs: Optional[str] = None,
    streaming: bool = False,
    dp_degree: Optional[int] = None,
    num_workers: int = 32,
    seed: Optional[int] = None,
) -> IterableDataset:
    """Load one or more HF datasets and return a single iterable dataset.

    When `dataset` contains a single path, that dataset is loaded, optionally
    shuffled, and (if needed) resharded so that every (dp rank × dataloader
    worker) pair gets at least one shard. When `dataset` is a comma-separated
    list, each entry is loaded the same way and the results are interleaved
    according to `data_probs`.

    Args:
        dataset: Comma-separated dataset path(s) understood by `load_dataset`.
        dataset_name: Comma-separated config name(s), one per dataset, or None.
        dataset_split: Comma-separated split(s); empty entries default to 'train'.
        data_dir: Comma-separated data dir(s), one per dataset, or None.
        data_files: Comma-separated data file spec(s), one per dataset, or None.
        data_probs: Comma-separated sampling probabilities; required when more
            than one dataset is given.
        streaming: Whether to stream instead of materializing the dataset.
        dp_degree: Data-parallel degree used to compute the minimum shard count.
        num_workers: Dataloader workers per rank (also `num_proc` for loading).
        seed: Shuffling/interleaving seed; no shuffling is performed when None.

    Returns:
        An `IterableDataset` ready for tokenization and batching.

    Raises:
        ValueError: if multiple datasets are given without `data_probs`, or a
            subset has neither a 'text' nor a 'content' column.
    """
    color = utils.Color
    # Every (dp rank × dataloader worker) pair needs at least one shard.
    min_num_shards = dp_degree * num_workers if dp_degree else None
    if len(dataset.split(',')) == 1:
        # Keep the original path string: `dataset` is rebound to the loaded
        # dataset object below, but the reshard fallback needs the path again.
        path = dataset
        dataset = load_dataset(
            path=path,
            name=dataset_name,
            split=dataset_split,
            data_dir=data_dir,
            data_files=data_files,
            # trust_remote_code=True,
            streaming=streaming,
            num_proc=num_workers if not streaming else None,
        )
        logger.info(f"Shuffling the dataset with seed {seed}")
        if not streaming:
            # the states of a map-style dataset are recoverable after shuffling
            if seed is not None:
                dataset = dataset.shuffle(seed=seed)
            if min_num_shards is not None:
                dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
        else:
            if min_num_shards is not None and dataset.num_shards < min_num_shards:
                logger.warning(
                    f"{color.red}"
                    f"Dataset {path} has insufficient shards ({dataset.num_shards}). "
                    f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                    f"{num_workers} dataloader workers. "
                    f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
                    f"{color.reset}"
                )
                # Bug fix: reload from the saved path string; previously the
                # already-loaded dataset object was passed as `path`.
                dataset = load_dataset(
                    path=path,
                    name=dataset_name,
                    split=dataset_split,
                    data_dir=data_dir,
                    data_files=data_files,
                    # trust_remote_code=True,
                    streaming=False,
                    num_proc=num_workers,
                )
                if seed is not None:
                    dataset = dataset.shuffle(seed=seed)
                dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
            else:
                if seed is not None:
                    dataset = shuffle(dataset, seed=seed)
    else:
        # Multiple datasets: load each one, then interleave by probability.
        datasets = dataset.split(",")
        if dataset_name is not None:
            dataset_names = [name or None for name in dataset_name.split(",")]
            assert len(dataset_names) == len(datasets), (
                "The number of dataset names must match the number of datasets"
            )
        else:
            dataset_names = [None] * len(datasets)
        if dataset_split is not None:
            dataset_splits = [split or "train" for split in dataset_split.split(",")]
            assert len(dataset_splits) == len(datasets), (
                "The number of dataset splits must match the number of datasets"
            )
        else:
            dataset_splits = ["train"] * len(datasets)
        if data_dir is not None:
            data_dirs = [d or None for d in data_dir.split(",")]
            assert len(data_dirs) == len(datasets), (
                "The number of data dirs must match the number of datasets"
            )
        else:
            data_dirs = [None] * len(datasets)
        if data_files is not None:
            data_files = data_files.split(",")
            assert len(data_files) == len(datasets), (
                "The number of data files must match the number of datasets"
            )
        else:
            data_files = [None] * len(datasets)
        if data_probs is not None:
            data_probs = [float(p) for p in data_probs.split(",")]
            assert len(data_probs) == len(datasets), (
                "The number of data probabilities must match the number of datasets"
            )
        else:
            raise ValueError(
                "Data sampling probabilities are required if using multiple datasets"
            )

        subsets = []
        for i, prob in enumerate(data_probs):
            subset = load_dataset(
                path=datasets[i],
                name=dataset_names[i],
                split=dataset_splits[i],
                data_dir=data_dirs[i],
                data_files=data_files[i],
                # trust_remote_code=True,
                streaming=streaming,
                num_proc=(
                    num_workers
                    if not streaming
                    else None
                ),
            )
            logger.info(
                f"Subset {color.cyan}{datasets[i]}"
                + (f":{dataset_names[i]} " if dataset_names[i] else " ")
                + f"(p = {prob:.3f}){color.reset}:\n"
                + f"{subset}"
            )

            logger.info(f"Shuffling the dataset with seed {seed}")
            if not streaming:
                # the states of a map-style dataset are recoverable after shuffling
                if seed is not None:
                    subset = subset.shuffle(seed=seed)
                if min_num_shards is not None:
                    subset = subset.to_iterable_dataset(num_shards=min_num_shards)
            else:
                if min_num_shards is not None and subset.num_shards < min_num_shards:
                    logger.warning(
                        f"{color.red}"
                        f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
                        f"Need {min_num_shards} shards minimum for desired data parallel workers × "
                        f"{num_workers} dataloader workers. "
                        f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
                        f"{color.reset}"
                    )
                    # again, it's ok to directly shuffle the map-style dataset
                    # we expect an error raised if the map-style dataset still has not enough data shards
                    subset = load_dataset(
                        path=datasets[i],
                        name=dataset_names[i],
                        split=dataset_splits[i],
                        data_dir=data_dirs[i],
                        data_files=data_files[i],
                        # trust_remote_code=True,
                        streaming=False,
                        num_proc=num_workers,
                    )
                    if seed is not None:
                        subset = subset.shuffle(seed=seed)
                    subset = subset.to_iterable_dataset(num_shards=min_num_shards)
                else:
                    # we set relatively small buffer size here as interleaving could provide some randomness
                    if seed is not None:
                        subset = shuffle(subset, seed=seed, buffer_size=max(128, 1024 // len(datasets)))

            # Keep only the single text column used for training.
            if "text" in subset.column_names:
                subset = subset.select_columns("text")
            elif "content" in subset.column_names:
                subset = subset.select_columns("content")
            else:
                raise ValueError(
                    f"Subset {datasets[i]} has no 'text' or 'content' column"
                )
            subsets.append(subset)

        logger.info(
            f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
        )
        dataset = interleave_datasets(
            datasets=subsets,
            probabilities=data_probs,
            stopping_strategy="all_exhausted",
            seed=seed,
        )
    logger.info(f"{dataset}")
    return dataset
729
+
730
+
731
def build_dataloader(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizer,
    rank: int,
    world_size: int,
    batch_size: int,
    seq_len: int,
    context_len: Optional[int] = None,
    varlen: bool = False,
    num_workers: int = 0,
    pin_memory: bool = False,
    persistent_workers: bool = False,
    snapshot_every_n_steps: Optional[int] = 1,
):
    """Build a rank-aware dataloader over an on-the-fly tokenized dataset.

    The raw text dataset is wrapped in `OnlineTokenizedIterableDataset` (which
    shards it across `world_size` ranks and tokenizes into fixed `seq_len`
    chunks) and fed through `ParallelAwareDataLoader` with a language-modeling
    collator.
    """
    tokenized = OnlineTokenizedIterableDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        seq_len=seq_len,
        rank=rank,
        world_size=world_size,
    )
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        context_len=context_len,
        varlen=varlen,
    )
    return ParallelAwareDataLoader(
        rank=rank,
        dataset=tokenized,
        batch_size=batch_size,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        snapshot_every_n_steps=snapshot_every_n_steps,
    )
flame/models/__init__.py ADDED
File without changes
flame/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (155 Bytes). View file
 
flame/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (143 Bytes). View file
 
flame/models/__pycache__/parallelize_fla.cpython-311.pyc ADDED
Binary file (23.6 kB). View file
 
flame/models/__pycache__/parallelize_fla.cpython-312.pyc ADDED
Binary file (22.1 kB). View file
 
flame/models/__pycache__/pipeline_fla.cpython-311.pyc ADDED
Binary file (6.37 kB). View file
 
flame/models/__pycache__/pipeline_fla.cpython-312.pyc ADDED
Binary file (5.75 kB). View file
 
flame/models/activation_offloading.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/training/_activation_offloading.py
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import contextlib
9
+ from typing import Union
10
+ from warnings import warn
11
+
12
+ import psutil
13
+ import torch
14
+ from torch import nn
15
+ from torch.autograd.graph import saved_tensors_hooks
16
+
17
+ from torchtitan.tools.logging import logger
18
+
19
+ try:
20
+ import torchao
21
+ from torchao.dtypes.nf4tensor import NF4Tensor
22
+ except ImportError:
23
+ torchao = None
24
+ NF4Tensor = None
25
+ logger.warning("torchao not found. ")
26
+
27
+ # from torchtune.modules import TiedLinear
28
+
29
+
30
class OffloadActivations(saved_tensors_hooks):
    """Context manager under which activation tensors created in the forward pass will be offloaded.

    Enable the memory efficiency technique of activation offloading, where activations bigger than
    min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward.
    This is in contrast to maintaining the activation on GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication
    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
    runtime vs memory usage.

    Args:
        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
            but is a limited resource. Default: True.

        use_streams (bool): Whether or not to use streams for performance optimization where
            the communications get overlapped with the computation. Requires a torch build
            after torch-2.5.0. Default: True.

        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
            consecutive activations to keep alive during the forward pass. This number must be at
            least 1. Keeping alive more activations will potentially allow more overlap between the
            communication and compute streams at the cost of increasing memory usage. Keeping alive
            fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime. Default: 5.

        min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify
            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
            moving it to CPU and back. Default: 1024 bytes.

    Raises:
        ValueError: if max_fwd_stash_size is not at least 1.

    Example:
        >>> with OffloadActivations():
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>>     loss.backward()
    """

    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        max_fwd_stash_size: int = 5,
        min_offload_size: int = 1024,
    ) -> None:

        self.use_streams: bool = use_streams

        self.min_tensor_size_bytes = (
            min_offload_size  # we don't want to bother with small tensors
        )
        self.tracker = (
            {}
        )  # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
        self.tensor_id: int = 0
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        # managing cpu memory
        self.use_pin_memory: bool = use_pin_memory
        self.virtual_memory_safe_pct = (
            60  # we should not exceed this percentage of memory
        )

        self.s0 = torch.cuda.default_stream()  # comp stream

        # for streaming
        if self.use_streams:
            self.s1 = torch.cuda.Stream()  # comms stream
            self.fwd_stash = {}  # tensor_id => (activation, ev1)
            if max_fwd_stash_size < 1:
                raise ValueError(
                    f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
                )
            self.max_fwd_stash_size = max_fwd_stash_size
            self.bwd_tensor_stash = {}  # tensor_id => activation
            self.bwd_ev_stash = {}  # tensor_id => ev0
            self.curr_graph_id = None
            self.curr_autograd_node = None

        # -------- platform util functions -------- #
        def verify_sufficient_virtual_memory():
            # Warn (once per run, checked on first backward) when pinned-memory
            # usage pushes system RAM past the safety threshold.
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warn(
                    f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
                )

        def get_cpu_ram_pct() -> float:
            # get the percentage of memory used by the system
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            # create a unique id for each tensor we are managing
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            # get the number of bytes in a tensor, for memory management purposes
            return (
                x.element_size() * x.nelement()
            )  # x.element_size() * x._base_storage().nbytes()

        # -------- core pack / unpack work -------- #
        def pack_tensor(activation: torch.Tensor) -> int:
            # activations are passed in during forward pass - from here we take over and return a unique id
            if self.is_first_forward_call:
                assert (
                    len(self.tracker) == 0
                ), "backward pass should have cleared tracker of all tensors"

                # set training phase trackers
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            # query for basic tensor info
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # only offload hefty bois if they're activations on CUDA (our heuristic
            # for that is to check if they're not params or buffers)!
            # NOTE(review): torch.nn.Buffer requires torch>=2.4 — confirm the
            # project's minimum supported torch version.
            if (
                activation.is_cuda
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not isinstance(activation, torch.nn.Buffer)
                )
            ):
                if self.use_streams:
                    # First, sync back and dereference previously offloaded tensors
                    # as the offloading should be done sufficiently long ago.
                    for id in [k for k in self.fwd_stash.keys()]:
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # Sync in, offload, and add an event to sync back later
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with torch.cuda.stream(stream):
                    try:
                        cpu_tensor = torch.empty_like(
                            activation, pin_memory=self.use_pin_memory, device="cpu"
                        )
                    except NotImplementedError as e:
                        # empty_like on NF4 quantized tensors is only supported
                        # by sufficiently recent torchao builds.
                        if (
                            isinstance(activation, NF4Tensor)
                            and torchao.__version__ < "0.6.0.dev20240917"
                        ):
                            raise RuntimeError(
                                "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
                            ) from e
                        raise e
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,
                    )  # True = (in future) modified

                if self.use_streams:
                    event = self.s1.record_event()

                    # Stash to keep activation alive til s1 is done
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is as is

            return tensor_id

        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # The tensor was offloaded in the forward; bring it back to GPU.
                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                maybe_gpu_tensor = gpu_tensor

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            # backward pass - we are called with the tensor_id, which
            # we will use to retrieve the saved/offloaded tensor
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    # Drain anything still stashed when autograd finishes.
                    for id in [k for k in self.bwd_tensor_stash.keys()]:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                # Register a callback to the end of autograd to clean everything up
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    wait_and_del_remaining_references
                )

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # Get data on the current autograd node
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # If we're on a new node, mark prev node's tensors to be freed later
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    # The activation never left the GPU (still in the forward
                    # stash), so reuse it directly.
                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Kick off the process to bring tensors back
                    with torch.cuda.stream(self.s1):
                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                        maybe_gpu_tensor = gpu_tensor

                    # Tell comp stream to wait for the info to be loaded before executing
                    self.s0.wait_stream(self.s1)

                # Stash the tensor to keep memory alive until compute stream is complete
                self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

                # Note: [Track views of the unpacked]
                # Why do we get the use count of the unpacked tensor here? We want an
                # initial count to compare to later, during the post-hook of the
                # backward node, when we need to decide whether we're allowed to free
                # the tensor yet. In what obscure cases must we delay freeing the
                # tensor (and thus call record_stream)?
                # 1. Any of the outputs of the backward node is a view of the unpacked
                #    tensor.
                # 2. In the case that this unpacked tensor will be used in a
                #    checkpointed region, if one of the recomputed saved tensors ends
                #    up as a view of the unpacked tensor.
                # 3. The user abuses the system somehow and manually relies on the
                #    unpacked tensor to exist after the backward node has executed.
                storage_refcount = torch._C._storage_Use_Count(
                    maybe_gpu_tensor.untyped_storage()._cdata
                )

                def hook(outputs, inputs):
                    # create events for the current node inputs/outputs if they were streamed in
                    if brought_back_from_cpu:
                        # See Note: [Track views of the unpacked]
                        # IF any of the outputs is a view of the tensor, OR if a view of
                        # the tensor has been saved as a part of checkpoint's recompute
                        # process, OR the user has abusedly incurred a reference on the
                        # unpacked tensor, THEN the tensor might be used later and we
                        # cannot presume to delete it after only the current node is
                        # done! So we use our frenemy, record_stream, to ensure the
                        # Tensor stays unmessed with until it's done getting used in the
                        # compute stream (s0 here). Note that the con here is we introduce
                        # non-deterministic (thus higher) memory usage, but this case
                        # should not happen often.
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if (
                            torch._C._storage_Use_Count(
                                unpacked_tensor.untyped_storage()._cdata
                            )
                            > storage_refcount
                        ):
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            event = self.s0.record_event()
                            self.bwd_ev_stash[unpack_tensor_id] = event

                    # if there are still things in the fwd_stash, get rid of them as we're in bwd now
                    for id in [k for k in self.fwd_stash.keys()]:
                        _, ev = self.fwd_stash[id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[id]

                    # wait on prev node's events and del those
                    for id in prev_node_ids:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                    return outputs

                node.register_hook(hook)

            # clear tensor from tracking
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        unpack_tensor = (
            unpack_tensor_with_streams
            if self.use_streams
            else unpack_tensor_single_stream
        )
        super().__init__(pack_tensor, unpack_tensor)
364
+
365
+
366
class NoOpManager(saved_tensors_hooks):
    """Pass-through saved-tensors hook that masks any previously applied manager.

    Registering this manager makes pack/unpack the identity function. Because
    only the most recently registered saved_tensors_hooks manager runs, entering
    this context effectively opts a local region of code out of a globally
    applied manager such as activation offloading.
    """

    def __init__(self) -> None:
        def passthrough(tensor):
            # Identity hook: the tensor is saved and restored untouched.
            return tensor

        super().__init__(passthrough, passthrough)
381
+
382
+
383
def get_act_offloading_ctx_manager(
    model: nn.Module, enable_activation_offloading: bool
) -> Union[OffloadActivations, contextlib.nullcontext]:
    """Returns the activation offloading context manager for the model, which will be
    a null context if enable_activation_offloading is False.

    If activation offloading is enabled, we return the OffloadActivations context
    manager and, when an output head is detected at `model.output`, register hooks
    that wrap its forward in a NoOpManager so the large logits tensor is never
    offloaded.
    If activation offloading is disabled, we return a `contextlib.nullcontext()`.

    Args:
        model (nn.Module): the model to wrap with the activation offloading context manager.
        enable_activation_offloading (bool): whether or not to enable activation offloading
            for the model.

    Returns:
        contextlib.ContextDecorator: the activation offloading context manager for the model.
    """
    if not enable_activation_offloading:
        return contextlib.nullcontext()

    activations_handling_ctx = OffloadActivations()

    # Below is our hack to disable offloading the last output Linear in every
    # step, as the cost for offloading the activation and then soon after bringing
    # it back is expensive. Moreover, due to heuristics in our streaming API,
    # we actually use more memory if we offload it as it interferes with chunkedCE.
    output_head_detected = False
    noop_ctx = NoOpManager()

    if hasattr(model, "output"):
        if isinstance(model.output, nn.Module):
            # Enter/exit the no-op manager around the head's forward so its
            # activations stay on device.
            model.output.register_forward_pre_hook(
                lambda *args: noop_ctx.__enter__()
            )
            model.output.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )
            # Use the project logger instead of a leftover debug print.
            logger.debug("Registered activation-offloading no-op hooks for model.output")
            output_head_detected = True
        # ================================
        # ! TODO[flame] check if we need to detal with TiedLinear
        # The following code appears in `torchtune`
        # elif isinstance(model.output, TiedLinear):
        #     model.output.linear.register_forward_pre_hook(
        #         lambda *args: noop_ctx.__enter__()
        #     )
        #     model.output.linear.register_forward_hook(
        #         lambda *args: noop_ctx.__exit__(), always_call=True
        #     )
        #     output_head_detected = True

    if not output_head_detected:
        logger.warning(
            "During activation offloading, no output head was detected. "
            "If your model has an output head, it will be offloaded. "
            "This usually greatly slows training, given the large vocabulary size. "
            "To change this behavior, set your output head as model.output and make it "
            "an nn.Module."
        )

    return activations_handling_ctx
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [model]
2
+ config = "fla-hub/hamilton-350M-15B"
3
+ tokenizer_path = "mistralai/Mistral-7B-v0.1"
4
+
5
+ [job]
6
+ dump_folder = "exp"
7
+ print_args = true
8
+
9
+ [training]
10
+ batch_size = 2
11
+ seq_len = 2048
12
+ context_len = 2048
13
+ gradient_accumulation_steps = 1
14
+ steps = 20480
15
+ max_norm = 1.0
16
+ skip_nan_inf = true
17
+ data_parallel_replicate_degree = 1
18
+ data_parallel_shard_degree = -1
19
+ tensor_parallel_degree = 1
20
+ compile = false
21
+ dataset = "SlimPajama-627B"
22
+ dataset_name = "default"
23
+ num_workers = 32
24
+ pin_memory = false
25
+ persistent_workers = false
26
+ prefetch_factor = 2
27
+ seed = 42
28
+ varlen = false
29
+
30
+ [optimizer]
31
+ name = "AdamW"
32
+ eps = 1e-15
33
+ lr = 3e-4
34
+
35
+ [lr_scheduler]
36
+ warmup_steps = 1024
37
+ decay_type = "cosine"
38
+ lr_min = 0.1
39
+
40
+ [checkpoint]
41
+ enable_checkpoint = true
42
+ folder = "checkpoint"
43
+ interval_type = "steps"
44
+ interval = 2048
45
+ model_weights_only = false
46
+ export_dtype = "float32"
47
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
48
+
49
+ [profiling]
50
+ enable_profiling = false
51
+ save_traces_folder = "profile_trace"
52
+ profile_freq = 256
53
+
54
+ [metrics]
55
+ log_freq = 32
56
+ enable_wandb = true
57
+
58
+ [experimental]
59
+ context_parallel_degree = 1
60
+ pipeline_parallel_degree = 1
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+
66
+ [activation_checkpoint]
67
+ mode = "none"
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    The order of application matters: TP first, then AC wrapping, then
    per-block compile, and finally FSDP/HSDP (or DDP) as the outermost wrapper.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        # Async TP relies on inductor micro-pipelining, so compile is mandatory.
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-block compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            # HSDP: replicate across one mesh dim, shard across the other.
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # Pure replication (DDP) only supports a 1-D device mesh.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )
109
+
110
+
111
class TPPlan:
    """Base tensor-parallel plan builder for an fla-style causal LM.

    Subclasses provide ``attn_plan`` for the architecture-specific token-mixing
    layer; the model-level, per-layer, and MLP plans are shared here.
    """

    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        # HF-style models keep the backbone under e.g. ``model.model``.
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # TODO(vkuzo): once float8 configuration supports delayed scaling,
        # add a check here to enforce supported float8 all-gather configurations
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            # torchao not installed: fall back to the plain TP styles below,
            # even when float8 was requested.
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput

    @property
    def model_plan(self):
        # Root-module plan: shard the embedding row-wise (output sharded on
        # the sequence dim) and run the final norm sequence-parallel.
        plans = {
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            plans.update(
                {
                    # NOTE(review): the inner ternaries are redundant —
                    # this branch already implies ``self.loss_parallel``.
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
                        use_local_output=not self.loss_parallel,
                    ),
                }
            )
        else:
            # Without loss parallelism, keep lm_head weights replicated and
            # let the fused linear + cross-entropy criterion be parallelized.
            plans.update(
                {
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans

    @property
    def layer_plan(self):
        # Per-block plan: norms run sequence-parallel around the sharded
        # attention and MLP sub-plans.
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        # Token-mixing layout differs per architecture; subclasses override.
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
        )

    @property
    def mlp_plan(self):
        # SwiGLU MLP: gate/up are column-sharded, down is row-sharded so the
        # block output returns to a sequence-sharded layout.
        return {
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }
202
+
203
+
204
class TransformerTPPlan(TPPlan):
    """TP plan for standard softmax-attention blocks (q/k/v/o projections)."""

    @property
    def attn_plan(self):
        # Inputs arrive sequence-sharded; replicate them for the head-sharded
        # q/k/v projections, then reduce back via the row-wise o projection.
        plan = {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        for proj in ("q_proj", "k_proj", "v_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
218
+
219
+
220
class GLATPPlan(TPPlan):
    """TP plan for GLA token-mixing blocks (adds gate projections and g_norm)."""

    @property
    def attn_plan(self):
        plan = {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
        }
        # q/k/v and the output gate are all column-sharded across heads.
        for proj in ("q_proj", "k_proj", "v_proj", "g_proj"):
            plan[f"attn.{proj}"] = self.colwise_parallel()
        # gk_proj is a two-stage projection: a replicated first linear
        # followed by a column-sharded expansion.
        plan["attn.gk_proj.0"] = PrepareModuleWeight(layouts=Replicate())
        plan["attn.gk_proj.1"] = self.colwise_parallel()
        # g_norm is sharded on the last (head) dim rather than the sequence dim.
        plan["attn.g_norm"] = SequenceParallel(sequence_dim=-1)
        plan["attn.o_proj"] = self.rowwise_parallel(output_layouts=Shard(1))
        return plan
238
+
239
+
240
# Maps `model.config.model_type` to its tensor-parallel plan class.
TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism.

    Raises KeyError if ``model.config.model_type`` has no entry in TP_PLAN_MAP.
    """
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        # Apply the shared per-layer plan to every transformer block.
        for _, block in enumerate(blocks):
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        # Async TP overlaps TP collectives with compute via inductor
        # micro-pipelining over symmetric memory.
        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
281
+
282
+
283
# for selective op activation checkpointing
# Ops whose outputs are always saved (never recomputed) under selective AC.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}
294
+
295
+
296
def _apply_ac_to_block(module: nn.Module, ac_config):
    """Wrap a single transformer block with activation checkpointing.

    Returns the (possibly wrapped) module; the caller must re-register it on
    the parent. Modes: "full" checkpoints the whole block; "selective" does
    either op-level SAC (``selective_ac_option == "op"``) or checkpoints every
    N-th block (``selective_ac_option`` is a digit string).

    Raises ValueError for an unknown mode or selective option.
    """
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

        def _get_custom_policy(meta):
            # `meta` tracks matmul counts separately for the forward and the
            # recompute pass so the two traversals stay in sync.
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            # Fresh counters for each checkpointed forward invocation.
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        # The counter is stored on the wrapper function itself, so it is
        # shared across all blocks (and all calls) within the process.
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module
353
+
354
+
355
def apply_ac(model: nn.Module, ac_config):
    """Wrap every transformer block of the model with activation checkpointing."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    # Replace each child in place with its (possibly wrapped) version.
    for name, child in blocks.named_children():
        blocks.register_module(name, _apply_ac_to_block(child, ac_config))

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure, and additionally compile the embedding, norm, and
    lm_head components individually.

    Compiled modules are re-registered on their parents so the change is
    visible to the caller (this function mutates `model` in place and
    returns None).
    """

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        # Compile each block separately; repeated structure means the
        # compiled artifact is reused across blocks.
        for layer_id, block in blocks.named_children():
            blocks.register_module(layer_id, torch.compile(block))
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    # Embedding and norm live on the inner model; lm_head lives on the wrapper.
    for owner, component in (
        (real_model, "tok_embeddings"),
        (real_model, "norm"),
        (model, "lm_head"),
    ):
        key = get_components_name(owner, component)
        if key is not None:
            compiled = torch.compile(getattr(owner, key), fullgraph=True)
            owner.register_module(key, compiled)

    # BUGFIX: the previous version ended with `model = torch.compile(model)`,
    # which only rebound the local name -- the compiled wrapper was discarded
    # (nothing was returned and callers ignore the return value), while the
    # log claimed the entire model had been compiled. The dead reassignment
    # and misleading log line have been removed; per-component compilation
    # above is the effective behavior, unchanged.
404
+
405
+
406
def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    Raises:
        ValueError: If ``reshard_after_forward_policy`` is not one of the
            recognized options.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        # Shard each block individually so FSDP can overlap communication
        # with per-block compute.
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard after forward for the last
                    # transformer block since FSDP would prefetch it immediately
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    # Shard the root module last to cover remaining parameters
    # (embeddings, norm, lm_head).
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    """Replicate the model with DDP over ``dp_mesh``."""
    if enable_compile:
        # Choose the dynamo DDP optimization strategy matching whether
        # compiled autograd is in use.
        optimize_mode = (
            "python_reducer_without_compiled_forward"
            if enable_compiled_autograd
            else "ddp_optimizer"
        )
        torch._dynamo.config.optimize_ddp = optimize_mode

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
487
+
488
+
489
def get_model(model):
    """Return the inner base model (e.g. ``model.model``) or None if absent.

    The attribute name comes from ``base_model_prefix`` when the wrapper
    defines it, falling back to ``"model"`` (the HF convention).
    """
    prefix = getattr(model, "base_model_prefix", "model")
    return getattr(model, prefix) if hasattr(model, prefix) else None
495
+
496
+
497
def get_blocks(model):
    """Return the stack of transformer blocks (``model.layers``) or None."""
    # TODO[flame]: adapt for network not using 'layers' attribute
    base = get_model(model)
    # hasattr(None, ...) is False, so a missing inner model also lands here.
    if hasattr(base, "layers"):
        return base.layers
    logger.warning('no "layers" in model can be found')
    return None
504
+
505
+
506
def get_components_name(model, component_name):
    """
    We try to catch tok_embeddings, norm layers and lm_head layers
    We do not catch the layer names in the blocks, for blocks see `get_blocks`
    We assume the model has the following structure:
        LlamaForCausalLM:
            Model:
                embed_tokens,
                layers,
                norm,
            lm_head
    ***
    so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
    and for 'lm_head' we need to pass `model`
    ***
    Returns the first matching attribute name, or None (with a warning) when
    no candidate attribute exists on ``model``.
    """
    # Candidate attribute names per component, in priority order.
    candidates = {
        "tok_embeddings": ("tok_embeddings", "embed_tokens", "embeddings"),
        "norm": ("norm", "norms", "layernorm"),
        "lm_head": ("lm_head",),
    }
    names = candidates.get(component_name)
    if names is None:
        # Unknown component: the original fell through silently.
        return None
    for attr in names:
        if hasattr(model, attr):
            return attr
    logger.warning(f"No {component_name} found in model")
    return None
flame/models/pipeline_fla.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D pipeline parallelism to the Llama model.
8
+
9
+ import copy
10
+ from typing import Callable, Optional, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.pipelining import PipelineStage
16
+ from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
17
+ from transformers import PretrainedConfig
18
+
19
+ from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
20
+ from torchtitan.config_manager import JobConfig
21
+ from torchtitan.distributed.parallel_dims import ParallelDims
22
+ from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
23
+ from torchtitan.tools.logging import logger
24
+
25
+ DeviceType = Union[int, str, torch.device]
26
+
27
+
28
def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    """Split the model into pipeline stages and build the PP schedule."""
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the
    # input_ids (first stage) and labels (last stage).
    has_first_stage = any(stage.is_first for stage in stages)
    has_last_stage = any(stage.is_last for stage in stages)

    return pp_schedule, models, has_first_stage, has_last_stage
53
+
54
+
55
def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.

    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    # Layer names at which to cut the model, either user-supplied or evenly
    # generated from the layer count.
    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )

    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        # Each stage starts from a deep copy of the full model and then
        # deletes the parts that belong to other stages.
        # NOTE: `num_stages` is defined after this closure but before any
        # call to it, so the reference resolves correctly at call time.
        model = copy.deepcopy(whole_model)
        if not is_first:
            # we do `model.tok_embeddings = None` here
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        # Layers before `start_layer` and from `stop_layer` onward are dropped.
        drop_layers = start_layer is not None
        # Get module dictionary from get_blocks(model)
        # and Create a list of keys before modifying dictionary
        module_dict = get_blocks(model)._modules  # Store reference
        layer_names = list(module_dict.keys())

        # Iterate over the list of keys instead of `_modules.items()`
        for name in layer_names:
            # Dynamically determine prefix (blocks.* or layers.*)
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"  # Construct the correct name format

            # Ensure `drop_layers` activation is based on actual naming
            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            # Delete layer if drop_layers is active
            if drop_layers:
                del module_dict[name]  # Safe deletion from stored dictionary

        if not is_last:
            # we do `model.norm = None` and `model.output = None`
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    # N split points produce N+1 stages.
    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    # ZBV schedules assign stages to ranks in a V pattern; others loop.
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models
flame/tools/__init__.py ADDED
File without changes
flame/tools/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (154 Bytes). View file
 
flame/tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (142 Bytes). View file
 
flame/tools/__pycache__/utils.cpython-311.pyc ADDED
Binary file (2.37 kB). View file
 
flame/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.15 kB). View file
 
flame/tools/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn
8
+ from torchtitan.tools.logging import logger
9
+
10
+
11
def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
    """Return (total parameter count, estimated FLOPs per token) for MFU accounting."""
    nparams = sum(p.numel() for p in model.parameters())
    # Embedding parameters are excluded from the 6N matmul term below,
    # since an embedding lookup is not a matmul.
    nparams_embedding = sum(
        sum(p.numel() for p in m.parameters())
        for m in model.children()
        if isinstance(m, nn.Embedding)
    )

    if hasattr(model_config, "num_heads"):
        num_heads = model_config.num_heads
    elif hasattr(model_config, "num_attention_heads"):
        num_heads = model_config.num_attention_heads
    else:
        num_heads = 1
        logger.warning("num_heads not found in model_config, defaulting to 1. ")

    layers = model_config.num_hidden_layers
    head_dim = model_config.hidden_size // num_heads

    # Reasoning behind the factor of 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
    # 2. the flash attention does 1 more matmul recomputation in the backward
    #    but recomputation should not be counted in calculating MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. we follow the convention and do not account for sparsity in causal attention
    attn_flops = 12 * layers * num_heads * head_dim * seq_len
    num_flops_per_token = 6 * (nparams - nparams_embedding) + attn_flops

    return nparams, num_flops_per_token
flame/train.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+
12
+ import fla # noqa
13
+ import torch
14
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
15
+ from fla.ops.utils import prepare_position_ids
16
+ from torch.distributed.elastic.multiprocessing.errors import record
17
+ from torchtitan.components.checkpoint import CheckpointManager
18
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
19
+ from torchtitan.components.loss import build_cross_entropy_loss
20
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
21
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
22
+ from torchtitan.components.optimizer import build_optimizers
23
+ from torchtitan.distributed import ParallelDims
24
+ from torchtitan.distributed import utils as dist_utils
25
+ from torchtitan.protocols.model_converter import build_model_converters
26
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
27
+ from torchtitan.tools import utils
28
+ from torchtitan.tools.logging import init_logger, logger
29
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
30
+
31
+ import custom_models
32
+ from flame.components.checkpoint import TrainState
33
+ from flame.config_manager import JobConfig
34
+ from flame.data import build_dataloader, build_dataset
35
+ from flame.models.parallelize_fla import parallelize_fla
36
+ from flame.models.pipeline_fla import pipeline_fla
37
+ from flame.tools.utils import get_nparams_and_flops
38
+
39
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
40
+ from fla.models import HamiltonForCausalLM as NewModelForCausalLM, HamiltonConfig as NewConfig
41
+ # from fla.models import GLAForCausalLM as NewModelForCausalLM, GLAConfig as NewConfig
42
+ MODEL_TYPE = NewConfig.model_type
43
+
44
+
45
def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
    """Instantiate the tokenizer named by ``model.tokenizer_path`` in the job config."""
    tokenizer_path = job_config.model.tokenizer_path
    return AutoTokenizer.from_pretrained(tokenizer_path)
47
+
48
+
49
# Register the "fla" training recipe with torchtitan so that
# `get_train_spec(job_config.model.name)` in `main` can resolve it.
# The model itself is constructed through HF Auto classes; parallelization,
# pipelining, data loading, and loss come from the flame/torchtitan helpers.
register_train_spec(
    TrainSpec(
        name="fla",
        cls=AutoModelForCausalLM,
        config=AutoConfig,
        parallelize_fn=parallelize_fla,
        pipelining_fn=pipeline_fla,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_dataloader,
        build_tokenizer_fn=build_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
63
+
64
+
65
# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
    """Run the full training job described by ``job_config``.

    High-level flow:
      1. initialize device, fault-tolerance manager, and distributed state;
      2. build tokenizer, dataset, and dataloader;
      3. build the model on the meta device, apply converters and
         parallelisms (TP/DP/CP; PP is currently disabled), then materialize
         and initialize weights;
      4. build optimizers / LR schedulers, restore from checkpoint if any;
      5. run the training loop with gradient accumulation, gradient clipping,
         NaN/Inf step skipping, metrics logging, checkpointing, and profiling.

    Raises:
        NotImplementedError: if pipeline parallelism is requested
            (``experimental.pipeline_parallel_degree`` > 1).
    """
    logger.info(f"Starting job: {job_config.job.description}")
    logger.info(f"Registering model type: {MODEL_TYPE}")

    if job_config.experimental.custom_model_path:
        utils.import_module_from_path(job_config.experimental.custom_model_path)

    # used for colorful printing
    color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color

    if job_config.job.print_args:
        logger.info(
            f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
        )

    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    device_module, device_type = utils.device_module, utils.device_type
    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
    # Device has to be set before creating TorchFT manager.
    device_module.set_device(device)
    ft_manager = init_ft_manager(job_config)

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    if not ft_manager.enabled:
        parallel_dims = ParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
        )
    else:
        parallel_dims = FTParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
            ft_manager=ft_manager,
        )
    dist_utils.init_distributed(job_config)
    # initialize device memory monitor and get peak flops for MFU calculation
    device_memory_monitor = build_device_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
    logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type=device_type)
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        raise NotImplementedError(
            "Pipeline parallelism is not supported in this version"
        )
        # NOTE: everything below in this branch is unreachable until PP support lands.
        """
        ! TODO[flame]: We need to fix the pipeline parallelism for flame
        [x] Match the key of models' components with the actual naming
        [ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically
            forces to tie if head is None, we need to handle this case
        [ ]
        """
        pp_mesh = world_mesh["pp"]

    # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
    dist_utils.set_determinism(
        world_mesh, device, job_config.training.seed, job_config.training.deterministic
    )
    train_spec = get_train_spec(job_config.model.name)

    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        job_config.model.tokenizer_path,
        trust_remote_code=True,
        model_max_length=int(1e10),
    )
    logger.info(f"{tokenizer}")
    # FIX: the optional ":{dataset_name}" suffix is parenthesized so the
    # conditional applies only to the suffix. Previously the bare ternary
    # bound to the entire concatenated f-string, so the whole message was
    # replaced by "" whenever `dataset_name` was None.
    logger.info(
        f"Loading dataset {job_config.training.dataset}"
        + (
            f":{job_config.training.dataset_name}"
            if job_config.training.dataset_name is not None
            else ""
        )
    )
    dataset = build_dataset(
        dataset=job_config.training.dataset,
        dataset_name=job_config.training.dataset_name,
        dataset_split=job_config.training.dataset_split,
        data_dir=job_config.training.data_dir,
        data_files=job_config.training.data_files,
        data_probs=job_config.training.data_probs,
        streaming=job_config.training.streaming,
        dp_degree=dp_degree,
        num_workers=job_config.training.num_workers,
        seed=job_config.training.seed,
    )

    logger.info("Building dataloader...")
    dataloader = build_dataloader(
        dataset=dataset,
        tokenizer=tokenizer,
        rank=dp_rank,
        world_size=dp_degree,
        batch_size=job_config.training.batch_size,
        seq_len=job_config.training.seq_len,
        context_len=job_config.training.context_len,
        varlen=job_config.training.varlen,
        num_workers=job_config.training.num_workers,
        pin_memory=job_config.training.pin_memory,
        persistent_workers=job_config.training.persistent_workers,
        snapshot_every_n_steps=job_config.checkpoint.interval,
    )

    logger.info(f"Loading model config from {job_config.model.config}")
    logger.info(f"Registering model type: {MODEL_TYPE}")
    AutoConfig.register(MODEL_TYPE, NewConfig)  # important!
    AutoModelForCausalLM.register(NewConfig, NewModelForCausalLM)  # important!
    model_config = AutoConfig.from_pretrained(job_config.model.config)
    # set the model configs from training inputs:
    # 1. norm type to decide which norm layer to use
    # 2. disable fused norm if TP is enabled
    # 3. vocab size from tokenizer
    # 4. context_len base on inputs
    if parallel_dims.tp_enabled:
        if model_config.fuse_norm:
            logger.warning(
                f"{color.red}"
                f"Fused norm is not compatible with tensor parallelism. "
                f"Disabling it for now."
                f"{color.reset}"
            )
            model_config.fuse_norm = False
    if parallel_dims.loss_parallel_enabled:
        if model_config.fuse_linear_cross_entropy:
            logger.warning(
                f"{color.red}"
                f"Loss parallel enabled. Disabling fused cross entropy for now."
                f"{color.reset}"
            )
            model_config.fuse_linear_cross_entropy = False
    model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)

    logger.info(
        f"Building model from the config\n{color.green}{model_config}{color.reset}"
    )
    # Build on the meta device: no memory is allocated until `to_empty` below.
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(model_config)
        if (
            getattr(model_config, "fuse_linear_cross_entropy", False)
            and FusedLinearCrossEntropyLoss is not None
        ):
            model.criterion = FusedLinearCrossEntropyLoss(
                num_chunks=8 // parallel_dims.tp
            )
    # defer weight initialization until after parallelisms are applied
    model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
    logger.info(f"{color.blue}\n{model}{color.reset}\n")

    logger.info("Applying model converters...")

    # Build the collection of model converters. No-op if `model.converters` empty
    model_converters = build_model_converters(job_config, parallel_dims)
    model_converters.convert(model)

    # calculate model size and flops per token
    model_param_count, num_flops_per_token = get_nparams_and_flops(
        model, model_config, job_config.training.context_len
    )

    # move sharded model to CPU/GPU and initialize weights via DTensor
    if job_config.checkpoint.create_seed_checkpoint:
        init_device = "cpu"
    elif job_config.training.enable_cpu_offload:
        init_device = "cpu"
    else:
        init_device = device_type

    # apply parallelisms and initialization
    if parallel_dims.pp_enabled:
        # apply PT-D Pipeline Parallel
        (
            pp_schedule,
            model_parts,
            has_first_stage,
            has_last_stage,
        ) = train_spec.pipelining_fn(
            model,
            pp_mesh,
            parallel_dims,
            job_config,
            device,
            model_config,
            train_spec.loss_fn,
        )
        # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
        del model

        # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
        # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
        # optimizer, and checkpointing
        for m in model_parts:
            # apply SPMD-style PT-D techniques
            train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
            m.to_empty(device=init_device)
            with torch.no_grad():
                m.post_init()
            m.train()

        # confirm that user will be able to view loss metrics on the console
        ensure_pp_loss_visible(parallel_dims, job_config, color)
    else:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
        model.to_empty(device=init_device)
        with torch.no_grad():
            model.post_init()
        model.train()

        model_parts = [model]

    device_mem_stats = device_memory_monitor.get_peak_stats()
    logger.info(
        f"{device_type.upper()} memory usage for model: "
        f"{device_mem_stats.max_reserved_gib:.2f}GiB"
        f"({device_mem_stats.max_reserved_pct:.2f}%)"
    )

    # build optimizer after applying parallelisms to the model
    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
    lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
    # Post optimizer step model converters hook.
    # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
    # where it issues a single all-reduce for all parameters at once for better performance
    optimizers.register_step_post_hook(
        lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
    )

    train_state = TrainState()

    # load initial checkpoint
    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    if job_config.checkpoint.create_seed_checkpoint:
        assert world_size == 1, (
            "Must create seed checkpoint using a single device, to disable sharding"
        )
        assert job_config.checkpoint.enable_checkpoint, (
            "Must enable checkpointing when creating a seed checkpoint"
        )
        checkpoint.save(curr_step=0, force=True)
        logger.info("Created seed checkpoint")
        return

    checkpoint.load(step=job_config.checkpoint.load_step)
    metric_logger = build_metrics_processor(job_config, parallel_dims)
    # Set dependent attributes for metric_logger
    metric_logger.num_flops_per_token = num_flops_per_token
    metric_logger.optimizers = optimizers  # Pass optimizers if needed by logger logic
    metric_logger.lr_schedulers = (
        lr_schedulers  # Pass schedulers if needed by logger logic
    )

    # plot losses loaded from checkpoint (if any) to TensorBoard
    # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
    # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
    if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
        for idx, step in enumerate(train_state.log_steps):
            metric_logger.log(
                step,
                global_avg_loss=train_state.global_avg_losses[idx],
                global_max_loss=train_state.global_max_losses[idx],
            )

    data_iterator = iter(dataloader)

    train_context = dist_utils.get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )
    maybe_enable_amp = dist_utils.maybe_enable_amp(
        parallel_dims,
        job_config.training.mixed_precision_param,
        device_type,
    )

    # variables used to keep info for metrics logging
    device_memory_monitor.reset_peak_stats()

    global_batch_size = (
        job_config.training.batch_size
        * dp_degree
        * job_config.training.gradient_accumulation_steps
    )
    num_tokens_per_step = global_batch_size * job_config.training.seq_len
    # train loop
    logger.info(f"{color.red}***** Running training *****{color.reset}")
    logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
    logger.info(
        f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
    )
    logger.info(
        f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
    )
    logger.info(
        f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
    )
    logger.info(
        f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
        f" ({num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Total optimization steps = {job_config.training.steps:,} "
        f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
        f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
    )

    with (
        maybe_enable_profiling(
            job_config, global_step=train_state.step
        ) as torch_profiler,
        maybe_enable_memory_snapshot(
            job_config, global_step=train_state.step
        ) as memory_profiler,
    ):
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            optimizers.zero_grad()

            losses = []
            # do gradient accumulation if enabled
            for _ in range(job_config.training.gradient_accumulation_steps):
                # get batch
                data_load_start = time.perf_counter()
                batch = next(data_iterator)
                input_ids, labels = batch["input_ids"], batch["labels"]

                # Update metrics processor state before forward/backward
                metric_logger.ntokens_since_last_log += labels.numel()
                metric_logger.data_loading_times.append(
                    time.perf_counter() - data_load_start
                )

                input_ids = input_ids.to(device_type)

                """
                TODO[flame]: We need to carefully handle the position_ids for TP/CP
                Depending on the Models'PE, the position_ids might be different.

                e.g. for TP
                For RoPE, all ranks have the same position_ids. [FOR HF model]
                For sinusoidal, each rank has the coresponding chunked position_ids. [FOR HF model]

                e.g. for CP, [optional_context_parallel_ctx shoudl automatically distbute the position_ids]
                Each rank has the coresponding chunked position_ids. [FOR All model]

                """
                labels = labels.to(device_type)
                cu_seqlens = (
                    batch["cu_seqlens"].to(device_type)
                    if "cu_seqlens" in batch
                    else None
                )
                if cu_seqlens is not None:
                    # varlen path: derive per-document positions from sequence boundaries
                    position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
                else:
                    position_ids = (
                        torch.arange(0, input_ids.shape[1], device=device_type)
                        .repeat(input_ids.shape[0], 1)
                        .to(torch.int32)
                    )
                # apply context parallelism if cp is enabled
                # ensure CP handles the separate freqs_cis buffer for each pp stage
                optional_context_parallel_ctx = (
                    dist_utils.create_context_parallel_ctx(
                        cp_mesh=world_mesh["cp"],
                        cp_buffers=[input_ids, labels, position_ids],
                        cp_seq_dims=[1, 1, 1],
                        cp_no_restore_buffers={input_ids, labels, position_ids},
                        cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
                    )
                    if parallel_dims.cp_enabled
                    else None
                )

                # #! TODO[flame], we should distribute the position_ids as well with CP
                if parallel_dims.pp_enabled:
                    raise NotImplementedError(
                        "Pipeline parallelism is not supported in this version"
                    )
                    # NOTE: unreachable until PP support is restored.
                    # Pipeline Parallel forward / backward inside step() call
                    with train_context(optional_context_parallel_ctx):
                        targets, losses = (
                            (labels, []) if has_last_stage else (None, None)
                        )

                        if has_first_stage:
                            pp_schedule.step(input_ids, target=targets, losses=losses)
                        else:
                            pp_schedule.step(target=targets, losses=losses)

                    # accumulate losses across pipeline microbatches
                    # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
                    loss = (
                        torch.mean(torch.stack(losses)).to(device)
                        if has_last_stage
                        else torch.tensor([-1.0], device=device)
                    )
                else:
                    # Non-PP forward / backward
                    with train_context(optional_context_parallel_ctx):
                        with maybe_enable_amp:
                            output = model(
                                input_ids=input_ids,
                                labels=labels,
                                position_ids=position_ids,
                                cu_seqlens=cu_seqlens,
                            )
                            # scale the loss so accumulated grads match a full batch
                            loss = (
                                output.loss
                                / job_config.training.gradient_accumulation_steps
                            )
                        loss.backward()

                losses.append(loss)
                del batch
            loss = sum(losses)

            # clip gradients
            grad_norm = dist_utils.clip_grad_norm_(
                [p for m in model_parts for p in m.parameters()],
                job_config.training.max_norm,
                foreach=True,
                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
            )

            # optimizer step
            checkpoint.maybe_wait_for_staging()
            if job_config.training.skip_nan_inf and (
                grad_norm.isnan() or grad_norm.isinf()
            ):
                logger.warning(
                    f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
                )
                optimizers.zero_grad()
                train_state.skipped_step += 1
            else:
                optimizers.step()
            lr_schedulers.step()

            # log metrics - Use MetricsProcessor
            if metric_logger.should_log(train_state.step):
                if (
                    parallel_dims.dp_replicate_enabled
                    or parallel_dims.dp_shard_enabled
                    or parallel_dims.cp_enabled
                ):
                    loss = loss.detach()
                    # Use dist_mean/max on the accumulated loss for the step
                    global_avg_loss, global_max_loss = (
                        dist_utils.dist_mean(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                        dist_utils.dist_max(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                    )
                else:
                    # Scale back the loss before logging
                    global_avg_loss = global_max_loss = loss.item()

                # Update train state tokens and elapsed time
                time_now = time.perf_counter()
                time_delta = (
                    time_now - metric_logger.time_last_log
                )  # Use metric_logger's time
                train_state.token += (
                    metric_logger.ntokens_since_last_log  # Use tokens tracked by metric_logger
                    * parallel_dims.world_size
                    / parallel_dims.non_data_parallel_size
                )
                train_state.elapsed += timedelta(seconds=time_delta)
                train_state.log_steps.append(train_state.step)
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)

                # Log using the metric processor
                last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
                eta = (
                    train_state.elapsed
                    * (job_config.training.steps - train_state.step)
                    / train_state.step
                )
                metric_logger.log(
                    train_state.step,
                    global_avg_loss,
                    global_max_loss,
                    extra_metrics={
                        "optimizer/lr": last_lr,
                        "optimizer/grad_norm": grad_norm.item(),
                        "optimizer/skipped_step": train_state.skipped_step,
                    },
                )

                logger.info(
                    f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
                    f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
                )

            checkpoint.save(
                train_state.step, force=(train_state.step == job_config.training.steps)
            )

            # signal the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # reduce timeout after first train step for faster signal
            # (assuming lazy init and compilation are finished)
            if train_state.step == 1:
                dist_utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    metric_logger.close()
    logger.info("Training completed")
630
+
631
+
632
if __name__ == "__main__":
    # Script entry point: set up logging, parse the job config from CLI
    # arguments, run training, then tear down the process group that
    # `main` initialized via dist_utils.init_distributed.
    init_logger()
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()
flame/train2.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+
12
+ import fla # noqa
13
+ import torch
14
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
15
+ from fla.ops.utils import prepare_position_ids
16
+ from torch.distributed.elastic.multiprocessing.errors import record
17
+ from torchtitan.components.checkpoint import CheckpointManager
18
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
19
+ from torchtitan.components.loss import build_cross_entropy_loss
20
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
21
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
22
+ from torchtitan.components.optimizer import build_optimizers
23
+ from torchtitan.distributed import ParallelDims
24
+ from torchtitan.distributed import utils as dist_utils
25
+ from torchtitan.protocols.model_converter import build_model_converters
26
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
27
+ from torchtitan.tools import utils
28
+ from torchtitan.tools.logging import init_logger, logger
29
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
30
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
31
+
32
+ import custom_models
33
+ from flame.components.checkpoint import TrainState
34
+ from flame.config_manager import JobConfig
35
+ from flame.data import build_dataloader, build_dataset
36
+ from flame.models.parallelize_fla import parallelize_fla
37
+ from flame.models.pipeline_fla import pipeline_fla
38
+ from flame.tools.utils import get_nparams_and_flops
39
+
40
+
41
def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
    """Build the tokenizer for this job from ``job_config.model.tokenizer_path``."""
    path = job_config.model.tokenizer_path
    return AutoTokenizer.from_pretrained(path)
43
+
44
+
45
# Make the "fla" recipe available to torchtitan's train-spec registry;
# `main` later looks it up via `get_train_spec(job_config.model.name)`.
register_train_spec(
    TrainSpec(
        name="fla",
        cls=AutoModelForCausalLM,
        config=AutoConfig,
        parallelize_fn=parallelize_fla,
        pipelining_fn=pipeline_fla,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_dataloader,
        build_tokenizer_fn=build_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
59
+
60
+
61
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
62
+ @record
63
+ def main(job_config: JobConfig):
64
+ logger.info(f"Starting job: {job_config.job.description}")
65
+
66
+ if job_config.experimental.custom_model_path:
67
+ utils.import_module_from_path(job_config.experimental.custom_model_path)
68
+
69
+ # used for colorful printing
70
+ color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
71
+
72
+ if job_config.job.print_args:
73
+ logger.info(
74
+ f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
75
+ )
76
+
77
+ # take control of garbage collection to avoid stragglers
78
+ gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
79
+
80
+ device_module, device_type = utils.device_module, utils.device_type
81
+ device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
82
+ # Device has to be set before creating TorchFT manager.
83
+ device_module.set_device(device)
84
+ ft_manager = init_ft_manager(job_config)
85
+
86
+ # init distributed
87
+ world_size = int(os.environ["WORLD_SIZE"])
88
+ if not ft_manager.enabled:
89
+ parallel_dims = ParallelDims(
90
+ dp_shard=job_config.training.data_parallel_shard_degree,
91
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
92
+ cp=job_config.experimental.context_parallel_degree,
93
+ tp=job_config.training.tensor_parallel_degree,
94
+ pp=job_config.experimental.pipeline_parallel_degree,
95
+ world_size=world_size,
96
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
97
+ )
98
+ else:
99
+ parallel_dims = FTParallelDims(
100
+ dp_shard=job_config.training.data_parallel_shard_degree,
101
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
102
+ cp=job_config.experimental.context_parallel_degree,
103
+ tp=job_config.training.tensor_parallel_degree,
104
+ pp=job_config.experimental.pipeline_parallel_degree,
105
+ world_size=world_size,
106
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
107
+ ft_manager=ft_manager,
108
+ )
109
+ dist_utils.init_distributed(job_config)
110
+ # initialize device memory monitor and get peak flops for MFU calculation
111
+ device_memory_monitor = build_device_memory_monitor()
112
+ gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
113
+ logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
114
+
115
+ # build meshes
116
+ world_mesh = parallel_dims.build_mesh(device_type=device_type)
117
+ if parallel_dims.dp_enabled:
118
+ dp_mesh = world_mesh["dp"]
119
+ dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
120
+ else:
121
+ dp_degree, dp_rank = 1, 0
122
+
123
+ if parallel_dims.pp_enabled:
124
+ raise NotImplementedError(
125
+ "Pipeline parallelism is not supported in this version"
126
+ )
127
+ """
128
+ ! TODO[flame]: We need to fix the pipeline parallelism for flame
129
+ [x] Match the key of models' components with the actual naming
130
+ [ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically
131
+ forces to tie if head is None, we need to handle this case
132
+ [ ]
133
+ """
134
+ pp_mesh = world_mesh["pp"]
135
+
136
+ # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
137
+ dist_utils.set_determinism(
138
+ world_mesh, device, job_config.training.seed, job_config.training.deterministic
139
+ )
140
+ train_spec = get_train_spec(job_config.model.name)
141
+
142
+ logger.info("Loading tokenizer...")
143
+ tokenizer = AutoTokenizer.from_pretrained(
144
+ job_config.model.tokenizer_path,
145
+ trust_remote_code=True,
146
+ model_max_length=int(1e10),
147
+ )
148
+ logger.info(f"{tokenizer}")
149
+ logger.info(
150
+ f"Loading dataset {job_config.training.dataset}"
151
+ f":{job_config.training.dataset_name}"
152
+ if job_config.training.dataset_name is not None
153
+ else ""
154
+ )
155
+ dataset = build_dataset(
156
+ dataset=job_config.training.dataset,
157
+ dataset_name=job_config.training.dataset_name,
158
+ dataset_split=job_config.training.dataset_split,
159
+ data_dir=job_config.training.data_dir,
160
+ data_files=job_config.training.data_files,
161
+ data_probs=job_config.training.data_probs,
162
+ streaming=job_config.training.streaming,
163
+ dp_degree=dp_degree,
164
+ num_workers=job_config.training.num_workers,
165
+ seed=job_config.training.seed,
166
+ )
167
+
168
+ logger.info("Building dataloader...")
169
+ dataloader = build_dataloader(
170
+ dataset=dataset,
171
+ tokenizer=tokenizer,
172
+ rank=dp_rank,
173
+ world_size=dp_degree,
174
+ batch_size=job_config.training.batch_size,
175
+ seq_len=job_config.training.seq_len,
176
+ context_len=job_config.training.context_len,
177
+ varlen=job_config.training.varlen,
178
+ num_workers=job_config.training.num_workers,
179
+ pin_memory=job_config.training.pin_memory,
180
+ persistent_workers=job_config.training.persistent_workers,
181
+ snapshot_every_n_steps=job_config.checkpoint.interval,
182
+ )
183
+
184
+
185
+ logger.info(f"Loading model config from {job_config.model.config}")
186
+ model_config = AutoConfig.from_pretrained(job_config.model.config)
187
+ # set the model configs from training inputs:
188
+ # 1. norm type to decide which norm layer to use
189
+ # 2. disable fused norm if TP is enabled
190
+ # 3. vocab size from tokenizer
191
+ # 4. context_len base on inputs
192
+ if parallel_dims.tp_enabled:
193
+ if model_config.fuse_norm:
194
+ logger.warning(
195
+ f"{color.red}"
196
+ f"Fused norm is not compatible with tensor parallelism. "
197
+ f"Disabling it for now."
198
+ f"{color.reset}"
199
+ )
200
+ model_config.fuse_norm = False
201
+ if parallel_dims.loss_parallel_enabled:
202
+ if model_config.fuse_linear_cross_entropy:
203
+ logger.warning(
204
+ f"{color.red}"
205
+ f"Loss parallel enabled. Disabling fused cross entropy for now."
206
+ f"{color.reset}"
207
+ )
208
+ model_config.fuse_linear_cross_entropy = False
209
+ model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
210
+
211
+
212
+
213
+ logger.info(
214
+ f"Building model from the config\n{color.green}{model_config}{color.reset}"
215
+ )
216
+ with torch.device("meta"):
217
+ model = AutoModelForCausalLM.from_config(model_config)
218
+ if (
219
+ getattr(model_config, "fuse_linear_cross_entropy", False)
220
+ and FusedLinearCrossEntropyLoss is not None
221
+ ):
222
+ model.criterion = FusedLinearCrossEntropyLoss(
223
+ num_chunks=8 // parallel_dims.tp
224
+ )
225
+ # defer weight initialization until after parallelisms are applied
226
+ model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
227
+ logger.info(f"{color.blue}\n{model}{color.reset}\n")
228
+
229
+ # Build the collection of model converters. No-op if `model.converters` empty
230
+ model_converters = build_model_converters(job_config, parallel_dims)
231
+ model_converters.convert(model)
232
+
233
+ # calculate model size and flops per token
234
+ model_param_count, num_flops_per_token = get_nparams_and_flops(
235
+ model, model_config, job_config.training.context_len
236
+ )
237
+
238
+ # move sharded model to CPU/GPU and initialize weights via DTensor
239
+ if job_config.checkpoint.create_seed_checkpoint:
240
+ init_device = "cpu"
241
+ elif job_config.training.enable_cpu_offload:
242
+ init_device = "cpu"
243
+ else:
244
+ init_device = device_type
245
+
246
+ # apply parallelisms and initialization
247
+ if parallel_dims.pp_enabled:
248
+ # apply PT-D Pipeline Parallel
249
+ (
250
+ pp_schedule,
251
+ model_parts,
252
+ has_first_stage,
253
+ has_last_stage,
254
+ ) = train_spec.pipelining_fn(
255
+ model,
256
+ pp_mesh,
257
+ parallel_dims,
258
+ job_config,
259
+ device,
260
+ model_config,
261
+ train_spec.loss_fn,
262
+ )
263
+ # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
264
+ del model
265
+
266
+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
267
+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
268
+ # optimizer, and checkpointing
269
+ for m in model_parts:
270
+ # apply SPMD-style PT-D techniques
271
+ train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
272
+ m.to_empty(device=init_device)
273
+ with torch.no_grad():
274
+ m.post_init()
275
+ m.train()
276
+
277
+ # confirm that user will be able to view loss metrics on the console
278
+ ensure_pp_loss_visible(parallel_dims, job_config, color)
279
+ else:
280
+ # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
281
+ train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
282
+ model.to_empty(device=init_device)
283
+ with torch.no_grad():
284
+ model.post_init()
285
+ model.train()
286
+
287
+ model_parts = [model]
288
+
289
+ device_mem_stats = device_memory_monitor.get_peak_stats()
290
+ logger.info(
291
+ f"{device_type.upper()} memory usage for model: "
292
+ f"{device_mem_stats.max_reserved_gib:.2f}GiB"
293
+ f"({device_mem_stats.max_reserved_pct:.2f}%)"
294
+ )
295
+
296
+ # build optimizer after applying parallelisms to the model
297
+ optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
298
+ lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
299
+ # Post optimizer step model converters hook.
300
+ # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
301
+ # where it issues a single all-reduce for all parameters at once for better performance
302
+ optimizers.register_step_post_hook(
303
+ lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
304
+ )
305
+
306
+ train_state = TrainState()
307
+
308
+ # load initial checkpoint
309
+ checkpoint = CheckpointManager(
310
+ dataloader=dataloader,
311
+ model_parts=model_parts,
312
+ optimizers=optimizers,
313
+ lr_schedulers=lr_schedulers,
314
+ states={"train_state": train_state},
315
+ job_config=job_config,
316
+ ft_manager=ft_manager,
317
+ )
318
+
319
+ if job_config.checkpoint.create_seed_checkpoint:
320
+ assert world_size == 1, (
321
+ "Must create seed checkpoint using a single device, to disable sharding"
322
+ )
323
+ assert job_config.checkpoint.enable_checkpoint, (
324
+ "Must enable checkpointing when creating a seed checkpoint"
325
+ )
326
+ checkpoint.save(curr_step=0, force=True)
327
+ logger.info("Created seed checkpoint")
328
+ return
329
+
330
+ checkpoint.load(step=job_config.checkpoint.load_step)
331
+ metric_logger = build_metrics_processor(job_config, parallel_dims)
332
+ # Set dependent attributes for metric_logger
333
+ metric_logger.num_flops_per_token = num_flops_per_token
334
+ metric_logger.optimizers = optimizers # Pass optimizers if needed by logger logic
335
+ metric_logger.lr_schedulers = (
336
+ lr_schedulers # Pass schedulers if needed by logger logic
337
+ )
338
+
339
+ # plot losses loaded from checkpoint (if any) to TensorBoard
340
+ # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
341
+ # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
342
+ if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
343
+ for idx, step in enumerate(train_state.log_steps):
344
+ metric_logger.log(
345
+ step,
346
+ global_avg_loss=train_state.global_avg_losses[idx],
347
+ global_max_loss=train_state.global_max_losses[idx],
348
+ )
349
+
350
+ data_iterator = iter(dataloader)
351
+
352
+ train_context = dist_utils.get_train_context(
353
+ parallel_dims.loss_parallel_enabled,
354
+ job_config.experimental.enable_compiled_autograd,
355
+ )
356
+ maybe_enable_amp = dist_utils.maybe_enable_amp(
357
+ parallel_dims,
358
+ job_config.training.mixed_precision_param,
359
+ device_type,
360
+ )
361
+
362
+ # variables used to keep info for metrics logging
363
+ device_memory_monitor.reset_peak_stats()
364
+
365
+ global_batch_size = (
366
+ job_config.training.batch_size
367
+ * dp_degree
368
+ * job_config.training.gradient_accumulation_steps
369
+ )
370
+ num_tokens_per_step = global_batch_size * job_config.training.seq_len
371
+ # train loop
372
+ logger.info(f"{color.red}***** Running training *****{color.reset}")
373
+ logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
374
+ logger.info(
375
+ f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
376
+ )
377
+ logger.info(
378
+ f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
379
+ )
380
+ logger.info(
381
+ f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
382
+ )
383
+ logger.info(
384
+ f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
385
+ f" ({num_tokens_per_step:,} tokens)"
386
+ )
387
+ logger.info(
388
+ f"{color.green} Total optimization steps = {job_config.training.steps:,} "
389
+ f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
390
+ )
391
+ logger.info(
392
+ f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
393
+ f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
394
+ )
395
+ logger.info(
396
+ f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
397
+ )
398
+
399
+ with (
400
+ maybe_enable_profiling(
401
+ job_config, global_step=train_state.step
402
+ ) as torch_profiler,
403
+ maybe_enable_memory_snapshot(
404
+ job_config, global_step=train_state.step
405
+ ) as memory_profiler,
406
+ ):
407
+ while train_state.step < job_config.training.steps:
408
+ train_state.step += 1
409
+ gc_handler.run(train_state.step)
410
+
411
+ optimizers.zero_grad()
412
+
413
+ losses = []
414
+ # do gradient accumulation if enabled
415
+ for _ in range(job_config.training.gradient_accumulation_steps):
416
+ # get batch
417
+ data_load_start = time.perf_counter()
418
+ batch = next(data_iterator)
419
+ input_ids, labels = batch["input_ids"], batch["labels"]
420
+
421
+ # Update metrics processor state before forward/backward
422
+ metric_logger.ntokens_since_last_log += labels.numel()
423
+ metric_logger.data_loading_times.append(
424
+ time.perf_counter() - data_load_start
425
+ )
426
+
427
+ input_ids = input_ids.to(device_type)
428
+
429
+ """
430
+ TODO[flame]: We need to carefully handle the position_ids for TP/CP
431
+ Depending on the Models'PE, the position_ids might be different.
432
+
433
+ e.g. for TP
434
+ For RoPE, all ranks have the same position_ids. [FOR HF model]
435
+ For sinusoidal, each rank has the corresponding chunked position_ids. [FOR HF model]
436
+
437
+ e.g. for CP, [optional_context_parallel_ctx should automatically distribute the position_ids]
438
+ Each rank has the corresponding chunked position_ids. [FOR All model]
439
+
440
+ """
441
+ labels = labels.to(device_type)
442
+ cu_seqlens = (
443
+ batch["cu_seqlens"].to(device_type)
444
+ if "cu_seqlens" in batch
445
+ else None
446
+ )
447
+ if cu_seqlens is not None:
448
+ position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
449
+ else:
450
+ position_ids = (
451
+ torch.arange(0, input_ids.shape[1], device=device_type)
452
+ .repeat(input_ids.shape[0], 1)
453
+ .to(torch.int32)
454
+ )
455
+ # apply context parallelism if cp is enabled
456
+ # ensure CP handles the separate freqs_cis buffer for each pp stage
457
+ optional_context_parallel_ctx = (
458
+ dist_utils.create_context_parallel_ctx(
459
+ cp_mesh=world_mesh["cp"],
460
+ cp_buffers=[input_ids, labels, position_ids],
461
+ cp_seq_dims=[1, 1, 1],
462
+ cp_no_restore_buffers={input_ids, labels, position_ids},
463
+ cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
464
+ )
465
+ if parallel_dims.cp_enabled
466
+ else None
467
+ )
468
+
469
+ # #! TODO[flame], we should distribute the position_ids as well with CP
470
+ if parallel_dims.pp_enabled:
471
+ raise NotImplementedError(
472
+ "Pipeline parallelism is not supported in this version"
473
+ )
474
+ # Pipeline Parallel forward / backward inside step() call
475
+ with train_context(optional_context_parallel_ctx):
476
+ targets, losses = (
477
+ (labels, []) if has_last_stage else (None, None)
478
+ )
479
+
480
+ if has_first_stage:
481
+ pp_schedule.step(input_ids, target=targets, losses=losses)
482
+ else:
483
+ pp_schedule.step(target=targets, losses=losses)
484
+
485
+ # accumulate losses across pipeline microbatches
486
+ # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
487
+ loss = (
488
+ torch.mean(torch.stack(losses)).to(device)
489
+ if has_last_stage
490
+ else torch.tensor([-1.0], device=device)
491
+ )
492
+ else:
493
+ # Non-PP forward / backward
494
+ with train_context(optional_context_parallel_ctx):
495
+ with maybe_enable_amp:
496
+ output = model(
497
+ input_ids=input_ids,
498
+ labels=labels,
499
+ position_ids=position_ids,
500
+ cu_seqlens=cu_seqlens,
501
+ )
502
+ loss = (
503
+ output.loss
504
+ / job_config.training.gradient_accumulation_steps
505
+ )
506
+ loss.backward()
507
+
508
+ losses.append(loss)
509
+ loss = sum(losses)
510
+
511
+ # clip gradients
512
+ grad_norm = dist_utils.clip_grad_norm_(
513
+ [p for m in model_parts for p in m.parameters()],
514
+ job_config.training.max_norm,
515
+ foreach=True,
516
+ pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
517
+ )
518
+
519
+ # optimizer step
520
+ checkpoint.maybe_wait_for_staging()
521
+ if job_config.training.skip_nan_inf and (
522
+ grad_norm.isnan() or grad_norm.isinf()
523
+ ):
524
+ logger.warning(
525
+ f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
526
+ )
527
+ optimizers.zero_grad()
528
+ train_state.skipped_step += 1
529
+ else:
530
+ optimizers.step()
531
+ lr_schedulers.step()
532
+
533
+ # log metrics - Use MetricsProcessor
534
+ if metric_logger.should_log(train_state.step):
535
+ if (
536
+ parallel_dims.dp_replicate_enabled
537
+ or parallel_dims.dp_shard_enabled
538
+ or parallel_dims.cp_enabled
539
+ ):
540
+ loss = loss.detach()
541
+ # Use dist_mean/max on the accumulated loss for the step
542
+ global_avg_loss, global_max_loss = (
543
+ dist_utils.dist_mean(
544
+ loss,
545
+ world_mesh["dp_cp"],
546
+ ),
547
+ dist_utils.dist_max(
548
+ loss,
549
+ world_mesh["dp_cp"],
550
+ ),
551
+ )
552
+ else:
553
+ # Scale back the loss before logging
554
+ global_avg_loss = global_max_loss = loss.item()
555
+
556
+ # Update train state tokens and elapsed time
557
+ time_now = time.perf_counter()
558
+ time_delta = (
559
+ time_now - metric_logger.time_last_log
560
+ ) # Use metric_logger's time
561
+ train_state.token += (
562
+ metric_logger.ntokens_since_last_log # Use tokens tracked by metric_logger
563
+ * parallel_dims.world_size
564
+ / parallel_dims.non_data_parallel_size
565
+ )
566
+ train_state.elapsed += timedelta(seconds=time_delta)
567
+ train_state.log_steps.append(train_state.step)
568
+ train_state.global_avg_losses.append(global_avg_loss)
569
+ train_state.global_max_losses.append(global_max_loss)
570
+
571
+ # Log using the metric processor
572
+ last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
573
+ eta = (
574
+ train_state.elapsed
575
+ * (job_config.training.steps - train_state.step)
576
+ / train_state.step
577
+ )
578
+ metric_logger.log(
579
+ train_state.step,
580
+ global_avg_loss,
581
+ global_max_loss,
582
+ extra_metrics={
583
+ "optimizer/lr": last_lr,
584
+ "optimizer/grad_norm": grad_norm.item(),
585
+ "optimizer/skipped_step": train_state.skipped_step,
586
+ },
587
+ )
588
+
589
+ logger.info(
590
+ f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
591
+ f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
592
+ )
593
+
594
+ checkpoint.save(
595
+ train_state.step, force=(train_state.step == job_config.training.steps)
596
+ )
597
+
598
+ # signal the profiler that the next profiling step has started
599
+ if torch_profiler:
600
+ torch_profiler.step()
601
+ if memory_profiler:
602
+ memory_profiler.step()
603
+
604
+ # reduce timeout after first train step for faster signal
605
+ # (assuming lazy init and compilation are finished)
606
+ if train_state.step == 1:
607
+ dist_utils.set_pg_timeouts(
608
+ timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
609
+ world_mesh=world_mesh,
610
+ )
611
+
612
+ if torch.distributed.get_rank() == 0:
613
+ logger.info("Sleeping 2 seconds for other ranks to complete")
614
+ time.sleep(2)
615
+
616
+ metric_logger.close()
617
+ logger.info("Training completed")
618
+
619
+
620
if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.parse_args()
    try:
        main(config)
    finally:
        # Tear down the default process group even if training raised, so
        # surviving ranks do not hang on collectives during shutdown. The
        # guard avoids an error when `main` returned before (or failed
        # during) distributed initialization, e.g. the seed-checkpoint path.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
flame/train_restart.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+
12
+ import fla # noqa
13
+ import torch
14
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
15
+ from fla.ops.utils import prepare_position_ids
16
+ from torch.distributed.elastic.multiprocessing.errors import record
17
+ from torchtitan.components.checkpoint import CheckpointManager
18
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
19
+ from torchtitan.components.loss import build_cross_entropy_loss
20
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
21
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
22
+ from torchtitan.components.optimizer import build_optimizers
23
+ from torchtitan.distributed import ParallelDims
24
+ from torchtitan.distributed import utils as dist_utils
25
+ from torchtitan.protocols.model_converter import build_model_converters
26
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
27
+ from torchtitan.tools import utils
28
+ from torchtitan.tools.logging import init_logger, logger
29
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
30
+
31
+ import custom_models
32
+ from flame.components.checkpoint import TrainState
33
+ from flame.config_manager import JobConfig
34
+ from flame.data import build_dataloader, build_dataset
35
+ from flame.models.parallelize_fla import parallelize_fla
36
+ from flame.models.pipeline_fla import pipeline_fla
37
+ from flame.tools.utils import get_nparams_and_flops
38
+
39
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
40
+ from fla.models import HamiltonForCausalLM as NewModelForCausalLM, HamiltonConfig as NewConfig
41
+ # from fla.models import GLAForCausalLM as NewModelForCausalLM, GLAConfig as NewConfig
42
+ MODEL_TYPE = NewConfig.model_type
43
+
44
def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
    """Load the tokenizer referenced by the job configuration.

    Mirrors the tokenizer construction inside ``main`` so both code paths
    produce the same tokenizer: ``trust_remote_code=True`` allows custom
    tokenizers that ship their own code (note: only use with trusted
    repositories), and ``model_max_length`` is effectively disabled since
    sequence length is controlled by the dataloader, not the tokenizer.

    Args:
        job_config: Parsed job configuration; only
            ``job_config.model.tokenizer_path`` is read.

    Returns:
        The pretrained ``AutoTokenizer`` instance.
    """
    return AutoTokenizer.from_pretrained(
        job_config.model.tokenizer_path,
        trust_remote_code=True,
        model_max_length=int(1e10),
    )
46
+
47
+
48
# Register the "fla" training spec so that `get_train_spec(job_config.model.name)`
# inside `main` can resolve it. All factory hooks are supplied by torchtitan
# components and the flame-local data/parallelism helpers.
_fla_train_spec = TrainSpec(
    name="fla",
    cls=AutoModelForCausalLM,
    config=AutoConfig,
    parallelize_fn=parallelize_fla,
    pipelining_fn=pipeline_fla,
    build_optimizers_fn=build_optimizers,
    build_lr_schedulers_fn=build_lr_schedulers,
    build_dataloader_fn=build_dataloader,
    build_tokenizer_fn=build_tokenizer,
    build_loss_fn=build_cross_entropy_loss,
)
register_train_spec(_fla_train_spec)
62
+
63
+
64
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
65
+ @record
66
+ def main(job_config: JobConfig):
67
+ logger.info(f"Starting job: {job_config.job.description}")
68
+ logger.info(f"Registering model type: {MODEL_TYPE}")
69
+
70
+ if job_config.experimental.custom_model_path:
71
+ utils.import_module_from_path(job_config.experimental.custom_model_path)
72
+
73
+ # used for colorful printing
74
+ color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
75
+
76
+ if job_config.job.print_args:
77
+ logger.info(
78
+ f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
79
+ )
80
+
81
+ # take control of garbage collection to avoid stragglers
82
+ gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
83
+
84
+ device_module, device_type = utils.device_module, utils.device_type
85
+ device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
86
+ # Device has to be set before creating TorchFT manager.
87
+ device_module.set_device(device)
88
+ ft_manager = init_ft_manager(job_config)
89
+
90
+ # init distributed
91
+ world_size = int(os.environ["WORLD_SIZE"])
92
+ if not ft_manager.enabled:
93
+ parallel_dims = ParallelDims(
94
+ dp_shard=job_config.training.data_parallel_shard_degree,
95
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
96
+ cp=job_config.experimental.context_parallel_degree,
97
+ tp=job_config.training.tensor_parallel_degree,
98
+ pp=job_config.experimental.pipeline_parallel_degree,
99
+ world_size=world_size,
100
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
101
+ )
102
+ else:
103
+ parallel_dims = FTParallelDims(
104
+ dp_shard=job_config.training.data_parallel_shard_degree,
105
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
106
+ cp=job_config.experimental.context_parallel_degree,
107
+ tp=job_config.training.tensor_parallel_degree,
108
+ pp=job_config.experimental.pipeline_parallel_degree,
109
+ world_size=world_size,
110
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
111
+ ft_manager=ft_manager,
112
+ )
113
+ dist_utils.init_distributed(job_config)
114
+ # initialize device memory monitor and get peak flops for MFU calculation
115
+ device_memory_monitor = build_device_memory_monitor()
116
+ gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
117
+ logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
118
+
119
+ # build meshes
120
+ world_mesh = parallel_dims.build_mesh(device_type=device_type)
121
+ if parallel_dims.dp_enabled:
122
+ dp_mesh = world_mesh["dp"]
123
+ dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
124
+ else:
125
+ dp_degree, dp_rank = 1, 0
126
+
127
+ if parallel_dims.pp_enabled:
128
+ raise NotImplementedError(
129
+ "Pipeline parallelism is not supported in this version"
130
+ )
131
+ """
132
+ ! TODO[flame]: We need to fix the pipeline parallelism for flame
133
+ [x] Match the key of models' components with the actual naming
134
+ [ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically
135
+ forces to tie if head is None, we need to handle this case
136
+ [ ]
137
+ """
138
+ pp_mesh = world_mesh["pp"]
139
+
140
+ # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
141
+ dist_utils.set_determinism(
142
+ world_mesh, device, job_config.training.seed, job_config.training.deterministic
143
+ )
144
+ train_spec = get_train_spec(job_config.model.name)
145
+
146
+ logger.info("Loading tokenizer...")
147
+ tokenizer = AutoTokenizer.from_pretrained(
148
+ job_config.model.tokenizer_path,
149
+ trust_remote_code=True,
150
+ model_max_length=int(1e10),
151
+ )
152
+ logger.info(f"{tokenizer}")
153
+ logger.info(
154
+ f"Loading dataset {job_config.training.dataset}"
155
+ f":{job_config.training.dataset_name}"
156
+ if job_config.training.dataset_name is not None
157
+ else ""
158
+ )
159
+ dataset = build_dataset(
160
+ dataset=job_config.training.dataset,
161
+ dataset_name=job_config.training.dataset_name,
162
+ dataset_split=job_config.training.dataset_split,
163
+ data_dir=job_config.training.data_dir,
164
+ data_files=job_config.training.data_files,
165
+ data_probs=job_config.training.data_probs,
166
+ streaming=job_config.training.streaming,
167
+ dp_degree=dp_degree,
168
+ num_workers=job_config.training.num_workers,
169
+ seed=job_config.training.seed,
170
+ )
171
+
172
+ logger.info("Building dataloader...")
173
+ dataloader = build_dataloader(
174
+ dataset=dataset,
175
+ tokenizer=tokenizer,
176
+ rank=dp_rank,
177
+ world_size=dp_degree,
178
+ batch_size=job_config.training.batch_size,
179
+ seq_len=job_config.training.seq_len,
180
+ context_len=job_config.training.context_len,
181
+ varlen=job_config.training.varlen,
182
+ num_workers=job_config.training.num_workers,
183
+ pin_memory=job_config.training.pin_memory,
184
+ persistent_workers=job_config.training.persistent_workers,
185
+ snapshot_every_n_steps=job_config.checkpoint.interval,
186
+ )
187
+
188
+
189
+ logger.info(f"Loading model config from {job_config.model.config}")
190
+ logger.info(f"Registering model type: {MODEL_TYPE}")
191
+ AutoConfig.register(MODEL_TYPE, NewConfig) # important!
192
+ AutoModelForCausalLM.register(NewConfig, NewModelForCausalLM) # important!
193
+ model_config = AutoConfig.from_pretrained(job_config.model.config)
194
+ # set the model configs from training inputs:
195
+ # 1. norm type to decide which norm layer to use
196
+ # 2. disable fused norm if TP is enabled
197
+ # 3. vocab size from tokenizer
198
+ # 4. context_len base on inputs
199
+ if parallel_dims.tp_enabled:
200
+ if model_config.fuse_norm:
201
+ logger.warning(
202
+ f"{color.red}"
203
+ f"Fused norm is not compatible with tensor parallelism. "
204
+ f"Disabling it for now."
205
+ f"{color.reset}"
206
+ )
207
+ model_config.fuse_norm = False
208
+ if parallel_dims.loss_parallel_enabled:
209
+ if model_config.fuse_linear_cross_entropy:
210
+ logger.warning(
211
+ f"{color.red}"
212
+ f"Loss parallel enabled. Disabling fused cross entropy for now."
213
+ f"{color.reset}"
214
+ )
215
+ model_config.fuse_linear_cross_entropy = False
216
+ model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
217
+
218
+
219
+
220
+ logger.info(
221
+ f"Building model from the config\n{color.green}{model_config}{color.reset}"
222
+ )
223
+ with torch.device("meta"):
224
+ model = AutoModelForCausalLM.from_config(model_config)
225
+ if (
226
+ getattr(model_config, "fuse_linear_cross_entropy", False)
227
+ and FusedLinearCrossEntropyLoss is not None
228
+ ):
229
+ model.criterion = FusedLinearCrossEntropyLoss(
230
+ num_chunks=8 // parallel_dims.tp
231
+ )
232
+ # defer weight initialization until after parallelisms are applied
233
+ model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
234
+ logger.info(f"{color.blue}\n{model}{color.reset}\n")
235
+
236
+ logger.info("Applying model converters...")
237
+
238
+ # Build the collection of model converters. No-op if `model.converters` empty
239
+ model_converters = build_model_converters(job_config, parallel_dims)
240
+ model_converters.convert(model)
241
+
242
+ # calculate model size and flops per token
243
+ model_param_count, num_flops_per_token = get_nparams_and_flops(
244
+ model, model_config, job_config.training.context_len
245
+ )
246
+
247
+ # move sharded model to CPU/GPU and initialize weights via DTensor
248
+ if job_config.checkpoint.create_seed_checkpoint:
249
+ init_device = "cpu"
250
+ elif job_config.training.enable_cpu_offload:
251
+ init_device = "cpu"
252
+ else:
253
+ init_device = device_type
254
+
255
+ # apply parallelisms and initialization
256
+ if parallel_dims.pp_enabled:
257
+ # apply PT-D Pipeline Parallel
258
+ (
259
+ pp_schedule,
260
+ model_parts,
261
+ has_first_stage,
262
+ has_last_stage,
263
+ ) = train_spec.pipelining_fn(
264
+ model,
265
+ pp_mesh,
266
+ parallel_dims,
267
+ job_config,
268
+ device,
269
+ model_config,
270
+ train_spec.loss_fn,
271
+ )
272
+ # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
273
+ del model
274
+
275
+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
276
+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
277
+ # optimizer, and checkpointing
278
+ for m in model_parts:
279
+ # apply SPMD-style PT-D techniques
280
+ train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
281
+ m.to_empty(device=init_device)
282
+ with torch.no_grad():
283
+ m.post_init()
284
+ m.train()
285
+
286
+ # confirm that user will be able to view loss metrics on the console
287
+ ensure_pp_loss_visible(parallel_dims, job_config, color)
288
+ else:
289
+ # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
290
+ train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
291
+ model.to_empty(device=init_device)
292
+ with torch.no_grad():
293
+ model.post_init()
294
+ model.train()
295
+
296
+ model_parts = [model]
297
+
298
+ device_mem_stats = device_memory_monitor.get_peak_stats()
299
+ logger.info(
300
+ f"{device_type.upper()} memory usage for model: "
301
+ f"{device_mem_stats.max_reserved_gib:.2f}GiB"
302
+ f"({device_mem_stats.max_reserved_pct:.2f}%)"
303
+ )
304
+
305
+ # build optimizer after applying parallelisms to the model
306
+ optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
307
+ lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
308
+ # Post optimizer step model converters hook.
309
+ # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
310
+ # where it issues a single all-reduce for all parameters at once for better performance
311
+ optimizers.register_step_post_hook(
312
+ lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
313
+ )
314
+
315
+ train_state = TrainState()
316
+
317
+ # load initial checkpoint
318
+ checkpoint = CheckpointManager(
319
+ dataloader=dataloader,
320
+ model_parts=model_parts,
321
+ optimizers=optimizers,
322
+ lr_schedulers=lr_schedulers,
323
+ states={"train_state": train_state},
324
+ job_config=job_config,
325
+ ft_manager=ft_manager,
326
+ )
327
+
328
+ # if job_config.training.streaming and job_config.checkpoint.enable_checkpoint:
329
+ # checkpoint = CheckpointManager(
330
+ # dataloader=None,  # use the variable
331
+ # model_parts=model_parts,
332
+ # optimizers=optimizers,
333
+ # lr_schedulers=lr_schedulers,
334
+ # states={"train_state": train_state},
335
+ # job_config=job_config,
336
+ # ft_manager=ft_manager,
337
+ # )
338
+ # if hasattr(checkpoint, 'states') and 'dataloader' in checkpoint.states:
339
+ # print("[Fix] Manually removing 'dataloader' from checkpoint states to avoid Missing Key error.")
340
+ # del checkpoint.states['dataloader']
341
+ # else:
342
+ # checkpoint = CheckpointManager(
343
+ # dataloader=dataloader,
344
+ # model_parts=model_parts,
345
+ # optimizers=optimizers,
346
+ # lr_schedulers=lr_schedulers,
347
+ # states={"train_state": train_state},
348
+ # job_config=job_config,
349
+ # ft_manager=ft_manager,
350
+ # )
351
+
352
+
353
+ if job_config.checkpoint.create_seed_checkpoint:
354
+ assert world_size == 1, (
355
+ "Must create seed checkpoint using a single device, to disable sharding"
356
+ )
357
+ assert job_config.checkpoint.enable_checkpoint, (
358
+ "Must enable checkpointing when creating a seed checkpoint"
359
+ )
360
+ checkpoint.save(curr_step=0, force=True)
361
+ logger.info("Created seed checkpoint")
362
+ return
363
+
364
+ logger.info(job_config.checkpoint)
365
+ checkpoint.load(step=job_config.checkpoint.load_step)
366
+ metric_logger = build_metrics_processor(job_config, parallel_dims)
367
+
368
+ # Set dependent attributes for metric_logger
369
+ metric_logger.num_flops_per_token = num_flops_per_token
370
+ metric_logger.optimizers = optimizers # Pass optimizers if needed by logger logic
371
+ metric_logger.lr_schedulers = (
372
+ lr_schedulers # Pass schedulers if needed by logger logic
373
+ )
374
+
375
+ # plot losses loaded from checkpoint (if any) to TensorBoard
376
+ # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
377
+ # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
378
+ if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
379
+ for idx, step in enumerate(train_state.log_steps):
380
+ metric_logger.log(
381
+ step,
382
+ global_avg_loss=train_state.global_avg_losses[idx],
383
+ global_max_loss=train_state.global_max_losses[idx],
384
+ )
385
+
386
+ data_iterator = iter(dataloader)
387
+
388
+ # if job_config.training.streaming and train_state.step > 0:
389
+ # from tqdm import tqdm
390
+ # import gc
391
+ # skip_count = 14
392
+ # # skip_count = train_state.step * job_config.training.gradient_accumulation_steps
393
+ # local_rank = torch.distributed.get_rank()
394
+ # if local_rank == 0:
395
+ # logger.info(f"Streaming Resume: Skipping {skip_count} micro-batches to catch up...")
396
+ # for i in tqdm(range(skip_count),
397
+ # desc="Skipping Data",
398
+ # unit="batch",
399
+ # disable=(local_rank != 0),
400
+ # dynamic_ncols=True):
401
+ # try:
402
+ # batch = next(data_iterator)
403
+ # del batch
404
+ # except StopIteration:
405
+ # if local_rank == 0:
406
+ # logger.warning(f"Data iterator exhausted before finishing skip at step {_}!")
407
+ # break
408
+
409
+ # if i % 500 == 0:
410
+ # gc.collect()
411
+ # gc.collect()
412
+
413
+ # if local_rank == 0:
414
+ # logger.info("Data skipping completed. Resuming training...")
415
+ # =================================================================
416
+
417
+
418
+
419
+
420
+ train_context = dist_utils.get_train_context(
421
+ parallel_dims.loss_parallel_enabled,
422
+ job_config.experimental.enable_compiled_autograd,
423
+ )
424
+ maybe_enable_amp = dist_utils.maybe_enable_amp(
425
+ parallel_dims,
426
+ job_config.training.mixed_precision_param,
427
+ device_type,
428
+ )
429
+
430
+ # variables used to keep info for metrics logging
431
+ device_memory_monitor.reset_peak_stats()
432
+
433
+ global_batch_size = (
434
+ job_config.training.batch_size
435
+ * dp_degree
436
+ * job_config.training.gradient_accumulation_steps
437
+ )
438
+ num_tokens_per_step = global_batch_size * job_config.training.seq_len
439
+ # train loop
440
+ logger.info(f"{color.red}***** Running training *****{color.reset}")
441
+ logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
442
+ logger.info(
443
+ f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
444
+ )
445
+ logger.info(
446
+ f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
447
+ )
448
+ logger.info(
449
+ f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
450
+ )
451
+ logger.info(
452
+ f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
453
+ f" ({num_tokens_per_step:,} tokens)"
454
+ )
455
+ logger.info(
456
+ f"{color.green} Total optimization steps = {job_config.training.steps:,} "
457
+ f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
458
+ )
459
+ logger.info(
460
+ f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
461
+ f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
462
+ )
463
+ logger.info(
464
+ f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
465
+ )
466
+
467
+ with (
468
+ maybe_enable_profiling(
469
+ job_config, global_step=train_state.step
470
+ ) as torch_profiler,
471
+ maybe_enable_memory_snapshot(
472
+ job_config, global_step=train_state.step
473
+ ) as memory_profiler,
474
+ ):
475
+ while train_state.step < job_config.training.steps:
476
+ train_state.step += 1
477
+ gc_handler.run(train_state.step)
478
+
479
+ optimizers.zero_grad()
480
+
481
+ losses = []
482
+ # do gradient accumulation if enabled
483
+ for _ in range(job_config.training.gradient_accumulation_steps):
484
+ # get batch
485
+ data_load_start = time.perf_counter()
486
+ batch = next(data_iterator)
487
+ input_ids, labels = batch["input_ids"], batch["labels"]
488
+
489
+ # Update metrics processor state before forward/backward
490
+ metric_logger.ntokens_since_last_log += labels.numel()
491
+ metric_logger.data_loading_times.append(
492
+ time.perf_counter() - data_load_start
493
+ )
494
+
495
+ input_ids = input_ids.to(device_type)
496
+
497
+ """
498
+ TODO[flame]: We need to carefully handle the position_ids for TP/CP
499
+ Depending on the Models'PE, the position_ids might be different.
500
+
501
+ e.g. for TP
502
+ For RoPE, all ranks have the same position_ids. [FOR HF model]
503
+ For sinusoidal, each rank has the coresponding chunked position_ids. [FOR HF model]
504
+
505
+ e.g. for CP, [optional_context_parallel_ctx shoudl automatically distbute the position_ids]
506
+ Each rank has the coresponding chunked position_ids. [FOR All model]
507
+
508
+ """
509
+ labels = labels.to(device_type)
510
+ cu_seqlens = (
511
+ batch["cu_seqlens"].to(device_type)
512
+ if "cu_seqlens" in batch
513
+ else None
514
+ )
515
+ if cu_seqlens is not None:
516
+ position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
517
+ else:
518
+ position_ids = (
519
+ torch.arange(0, input_ids.shape[1], device=device_type)
520
+ .repeat(input_ids.shape[0], 1)
521
+ .to(torch.int32)
522
+ )
523
+ # apply context parallelism if cp is enabled
524
+ # ensure CP handles the separate freqs_cis buffer for each pp stage
525
+ optional_context_parallel_ctx = (
526
+ dist_utils.create_context_parallel_ctx(
527
+ cp_mesh=world_mesh["cp"],
528
+ cp_buffers=[input_ids, labels, position_ids],
529
+ cp_seq_dims=[1, 1, 1],
530
+ cp_no_restore_buffers={input_ids, labels, position_ids},
531
+ cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
532
+ )
533
+ if parallel_dims.cp_enabled
534
+ else None
535
+ )
536
+
537
+ # #! TODO[flame], we should distribute the position_ids as well with CP
538
+ if parallel_dims.pp_enabled:
539
+ raise NotImplementedError(
540
+ "Pipeline parallelism is not supported in this version"
541
+ )
542
+ # Pipeline Parallel forward / backward inside step() call
543
+ with train_context(optional_context_parallel_ctx):
544
+ targets, losses = (
545
+ (labels, []) if has_last_stage else (None, None)
546
+ )
547
+
548
+ if has_first_stage:
549
+ pp_schedule.step(input_ids, target=targets, losses=losses)
550
+ else:
551
+ pp_schedule.step(target=targets, losses=losses)
552
+
553
+ # accumulate losses across pipeline microbatches
554
+ # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
555
+ loss = (
556
+ torch.mean(torch.stack(losses)).to(device)
557
+ if has_last_stage
558
+ else torch.tensor([-1.0], device=device)
559
+ )
560
+ else:
561
+ # Non-PP forward / backward
562
+ with train_context(optional_context_parallel_ctx):
563
+ with maybe_enable_amp:
564
+ output = model(
565
+ input_ids=input_ids,
566
+ labels=labels,
567
+ position_ids=position_ids,
568
+ cu_seqlens=cu_seqlens,
569
+ )
570
+ loss = (
571
+ output.loss
572
+ / job_config.training.gradient_accumulation_steps
573
+ )
574
+ loss.backward()
575
+ # print('--------------------------')
576
+
577
+ losses.append(loss)
578
+ loss = sum(losses)
579
+
580
+ # clip gradients
581
+ grad_norm = dist_utils.clip_grad_norm_(
582
+ [p for m in model_parts for p in m.parameters()],
583
+ job_config.training.max_norm,
584
+ foreach=True,
585
+ pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
586
+ )
587
+
588
+ # optimizer step
589
+ checkpoint.maybe_wait_for_staging()
590
+ if job_config.training.skip_nan_inf and (
591
+ grad_norm.isnan() or grad_norm.isinf()
592
+ ):
593
+ logger.warning(
594
+ f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
595
+ )
596
+ optimizers.zero_grad()
597
+ train_state.skipped_step += 1
598
+ else:
599
+ optimizers.step()
600
+ lr_schedulers.step()
601
+
602
+ # log metrics - Use MetricsProcessor
603
+ if metric_logger.should_log(train_state.step):
604
+ if (
605
+ parallel_dims.dp_replicate_enabled
606
+ or parallel_dims.dp_shard_enabled
607
+ or parallel_dims.cp_enabled
608
+ ):
609
+ loss = loss.detach()
610
+ # Use dist_mean/max on the accumulated loss for the step
611
+ global_avg_loss, global_max_loss = (
612
+ dist_utils.dist_mean(
613
+ loss,
614
+ world_mesh["dp_cp"],
615
+ ),
616
+ dist_utils.dist_max(
617
+ loss,
618
+ world_mesh["dp_cp"],
619
+ ),
620
+ )
621
+ else:
622
+ # Scale back the loss before logging
623
+ global_avg_loss = global_max_loss = loss.item()
624
+
625
+ # Update train state tokens and elapsed time
626
+ time_now = time.perf_counter()
627
+ time_delta = (
628
+ time_now - metric_logger.time_last_log
629
+ ) # Use metric_logger's time
630
+ train_state.token += (
631
+ metric_logger.ntokens_since_last_log # Use tokens tracked by metric_logger
632
+ * parallel_dims.world_size
633
+ / parallel_dims.non_data_parallel_size
634
+ )
635
+ train_state.elapsed += timedelta(seconds=time_delta)
636
+ train_state.log_steps.append(train_state.step)
637
+ train_state.global_avg_losses.append(global_avg_loss)
638
+ train_state.global_max_losses.append(global_max_loss)
639
+
640
+ # Log using the metric processor
641
+ last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
642
+ eta = (
643
+ train_state.elapsed
644
+ * (job_config.training.steps - train_state.step)
645
+ / train_state.step
646
+ )
647
+ metric_logger.log(
648
+ train_state.step,
649
+ global_avg_loss,
650
+ global_max_loss,
651
+ extra_metrics={
652
+ "optimizer/lr": last_lr,
653
+ "optimizer/grad_norm": grad_norm.item(),
654
+ "optimizer/skipped_step": train_state.skipped_step,
655
+ },
656
+ )
657
+
658
+ logger.info(
659
+ f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
660
+ f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
661
+ )
662
+
663
+ checkpoint.save(
664
+ train_state.step, force=(train_state.step == job_config.training.steps)
665
+ )
666
+
667
+ # signal the profiler that the next profiling step has started
668
+ if torch_profiler:
669
+ torch_profiler.step()
670
+ if memory_profiler:
671
+ memory_profiler.step()
672
+
673
+ # reduce timeout after first train step for faster signal
674
+ # (assuming lazy init and compilation are finished)
675
+ if train_state.step == 1:
676
+ dist_utils.set_pg_timeouts(
677
+ timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
678
+ world_mesh=world_mesh,
679
+ )
680
+
681
+ if torch.distributed.get_rank() == 0:
682
+ logger.info("Sleeping 2 seconds for other ranks to complete")
683
+ time.sleep(2)
684
+
685
+ metric_logger.close()
686
+ logger.info("Training completed")
687
+
688
+
689
+ if __name__ == "__main__":
690
+ init_logger()
691
+ config = JobConfig()
692
+ config.parse_args()
693
+ main(config)
694
+ torch.distributed.destroy_process_group()
flame/utils/__init__.py ADDED
File without changes
flame/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (138 Bytes). View file
 
flame/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (154 Bytes). View file
 
flame/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (142 Bytes). View file
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-310.pyc ADDED
Binary file (2.11 kB). View file
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-311.pyc ADDED
Binary file (4.46 kB). View file
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc ADDED
Binary file (4.06 kB). View file
 
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from datetime import timedelta
9
+
10
+ import fla # noqa
11
+ import torch
12
+ import torch.serialization
13
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
14
+ from torchtitan.tools.logging import init_logger, logger
15
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
16
+
17
+ from fla.models import HamiltonForCausalLM as NewModelForCausalLM, HamiltonConfig as NewConfig
18
+ MODEL_TYPE = NewConfig.model_type
19
+
20
+ # import custom_models
21
+
22
+
23
@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    """Convert a DCP (distributed checkpoint) into a HuggingFace-style model dir.

    Args:
        path: Output directory; must contain ``checkpoint/step-{step}``
            holding the DCP shards of the step to convert.
        step: Training step of the checkpoint to convert.
        config: Path or hub id of the model config to load.
        tokenizer: Path or hub id of the tokenizer to load.

    Raises:
        FileNotFoundError: If the DCP checkpoint directory does not exist.
    """
    logger.info(f"Loading the config from {config}")
    AutoConfig.register(MODEL_TYPE, NewConfig)  # important!
    # Use a fresh name instead of shadowing the `config` parameter.
    hf_config = AutoConfig.from_pretrained(config, trust_remote_code=True)

    logger.info(f"Saving the config to {path}")
    hf_config.save_pretrained(path)

    logger.info(f"Loading the tokenizer from {tokenizer}")
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    hf_tokenizer.save_pretrained(path)

    checkpoint = os.path.join(path, f'checkpoint/step-{step}')
    # Fail early with a clear message instead of deep inside DCP internals.
    if not os.path.exists(checkpoint):
        raise FileNotFoundError(f"DCP checkpoint not found at {checkpoint}")

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
        # Consolidate the sharded DCP checkpoint into a single torch.save
        # file (the old log line misleadingly said "Saving").
        logger.info(f"Converting the distributed checkpoint to {checkpoint_path}")
        dcp_to_torch_save(checkpoint, checkpoint_path)

        logger.info(f"Initializing the model from config\n{hf_config}")
        AutoModelForCausalLM.register(NewConfig, NewModelForCausalLM)  # important!
        model = AutoModelForCausalLM.from_config(hf_config)

        logger.info(model)
        logger.info("Loading state dict from the checkpoint")

        # The train-state payload contains timedelta and BytesIO objects;
        # allow-list them so torch.load with default weights_only=True works.
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])

    logger.info(f"Saving the model to {path}")
    model.save_pretrained(path)
64
+
65
+
66
if __name__ == "__main__":
    # CLI wrapper around save_pretrained(); all four flags are mandatory.
    init_logger()
    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
    for flag, kind in (("--path", str), ("--step", int), ("--config", str), ("--tokenizer", str)):
        parser.add_argument(flag, type=kind, required=True)
    args = parser.parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed.checkpoint as DCP
9
+ from transformers import AutoModelForCausalLM
10
+
11
+ import fla # noqa
12
+ from torchtitan.tools.logging import init_logger, logger
13
+
14
+
15
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    """Convert huggingface-style model weights into DCP format.

    Args:
        model: Path or hub id of the pretrained HF model to load.
        checkpoint: Output directory for the DCP shards (created if missing).
            Annotation fixed from ``str`` to ``Path`` — the body calls
            ``checkpoint.mkdir`` and the CLI passes ``type=Path``.
    """
    logger.info(f"Loading model from {model}")
    # Avoid shadowing the `model` parameter (a path/hub id) with the object.
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = hf_model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
25
+
26
+
27
if __name__ == "__main__":
    # CLI wrapper around convert_hf_weights(); both flags are mandatory.
    init_logger()
    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
    for flag, kind in (("--model", str), ("--checkpoint", Path)):
        parser.add_argument(flag, type=kind, required=True)
    args = parser.parse_args()
    convert_hf_weights(args.model, args.checkpoint)
flame/utils/preprocess.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from typing import Any, Dict, List
6
+
7
+ from transformers import AutoTokenizer, PreTrainedTokenizer
8
+
9
+ from flame.data import build_dataset
10
+ from torchtitan.tools.logging import init_logger, logger
11
+
12
+
13
def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
) -> Dict:
    """Tokenize a batch of examples and report compression as bits per token.

    Args:
        examples: Batch dict with raw text under a ``text`` or ``content`` key.
        tokenizer: HF tokenizer (or any callable with the same interface);
            called on the list of raw strings.

    Returns:
        Dict with ``input_ids`` (token id lists per sample) and
        ``bits_per_token`` (UTF-8 bits of the raw text / its token count).

    Raises:
        ValueError: If neither ``text`` nor ``content`` is present.
    """
    if 'text' in examples:
        samples = examples['text']
    elif 'content' in examples:
        samples = examples['content']
    else:
        raise ValueError(f'No "text" or "content" field found in examples:\n{examples}')
    input_ids = tokenizer(samples)['input_ids']
    # max(..., 1) guards against ZeroDivisionError for samples that tokenize
    # to zero tokens (e.g. empty strings in the raw data).
    bits_per_token = [
        len(sample.encode(encoding='utf-8')) * 8 / max(len(input_ids[i]), 1)
        for i, sample in enumerate(samples)
    ]
    return {'input_ids': input_ids, 'bits_per_token': bits_per_token}
26
+
27
+
28
if __name__ == '__main__':
    init_logger()
    # One-off preprocessing CLI: load (and optionally mix) datasets,
    # tokenize them in parallel via tokenize(), then persist to disk.
    parser = argparse.ArgumentParser(description='Preprocess the dataset.')
    parser.add_argument(
        '--dataset',
        default='HuggingFaceFW/fineweb-edu',
        help='Dataset to use, with comma separated values',
    )
    parser.add_argument(
        '--dataset_name',
        default='sample-100BT',
        help='The name of the dataset config, with comma separated values if provided',
    )
    parser.add_argument(
        '--dataset_split',
        default='train',
        help='Dataset split to use, with comma separated values if provided',
    )
    parser.add_argument(
        '--data_dir',
        default=None,
        help='Data dirs to use, with comma separated values if provided',
    )
    parser.add_argument(
        '--data_files',
        default=None,
        help='Data files to use, with comma separated values if provided',
    )
    parser.add_argument(
        '--data_probs',
        default=None,
        help='Data sampling probabilities, with comma separated values if provided',
    )
    parser.add_argument(
        '--streaming',
        action='store_true',
        help='Whether to use streaming mode',
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=64,
        help='Number of workers to use for preprocessing',
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for preprocessing',
    )
    parser.add_argument(
        '--path',
        default='data',
        help='Path to save the preprocessed dataset',
    )
    parser.add_argument(
        '--tokenizer',
        default='fla-hub/transformer-1.3B-100B',
        help='Tokenizer to use',
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2048,
        help="Batch size for processing"
    )
    args = parser.parse_args()

    logger.info(f'Loading tokenizer {args.tokenizer}')
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    logger.info(f'{tokenizer}')
    logger.info(f'Loading dataset {args.dataset} {args.dataset_name} {args.dataset_split}')
    # build_dataset handles the comma-separated multi-dataset mixing
    # (names/splits/dirs/files/probs are forwarded verbatim).
    dataset = build_dataset(
        dataset=args.dataset,
        dataset_name=args.dataset_name,
        dataset_split=args.dataset_split,
        data_dir=args.data_dir,
        data_files=args.data_files,
        data_probs=args.data_probs,
        streaming=args.streaming,
        num_workers=args.num_workers,
        seed=args.seed,
    )
    logger.info(f'Tokenizing and processing the dataset with batch size {args.batch_size}')
    # remove_columns: peek at one example to discover the original column
    # names so only the tokenizer outputs are kept.
    # NOTE(review): num_proc with --streaming may be unsupported by
    # datasets.map on iterable datasets — confirm intended usage.
    dataset = dataset.map(
        lambda examples: tokenize(examples, tokenizer),
        batched=True,
        batch_size=args.batch_size,
        remove_columns=list(next(iter(dataset)).keys()),
        num_proc=args.num_workers,
        desc="Running tokenizer on dataset"
    )
    logger.info(f'{dataset}')
    logger.info(f'Saving tokenized dataset to {args.path}')
    dataset.save_to_disk(args.path)