File size: 14,845 Bytes
cd16f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
from __future__ import annotations

import bisect
import functools
import importlib.util
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from xqs_stack import choose_optimizer_backend

# Name of the JSON file, stored next to the shards, that caches per-shard sample counts.
SHARD_INDEX_FILENAME: str = "shard_index.json"
# While building that index, a progress event is printed after every this many shards.
SHARD_INDEX_PROGRESS_EVERY: int = 256


@dataclass
class TrainStackConfig:
    """Hyper-parameters shared by the optimizer, dataset and dataloader builders.

    ``__post_init__`` rejects obviously invalid sizes up front, so that a bad
    value fails with a clear message instead of deep inside ``DataLoader`` or
    batch collation.

    Raises:
        ValueError: If a size/count field is out of range.
    """

    # Optimizer key understood by ``build_optimizer`` (e.g. "adafactor", "auto").
    optimizer_name: str = "adafactor"
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    # Samples per DataLoader batch.
    batch_size: int = 4
    # NOTE(review): not read by the builders visible here; presumably consumed
    # by the surrounding training loop.
    grad_accum_steps: int = 1
    # 0 means data is loaded in the main process.
    num_workers: int = 2
    pin_memory: bool = True
    # Only forwarded to the DataLoader when num_workers > 0.
    prefetch_factor: int = 4
    # Only honoured when num_workers > 0.
    persistent_workers: bool = True
    # Samples are truncated / collated to max_seq_len + 1 tokens.
    max_seq_len: int = 2048
    # Presumably passed to build_dataset, where "" selects the synthetic dataset.
    dataset_dir: str = ""
    # Presumably forwarded to train_demo_steps (bf16 autocast on CUDA only).
    use_bf16: bool = True

    def __post_init__(self) -> None:
        """Validate size and count fields; raise ``ValueError`` on nonsense values."""
        must_be_positive = {
            "batch_size": self.batch_size,
            "grad_accum_steps": self.grad_accum_steps,
            "max_seq_len": self.max_seq_len,
        }
        # prefetch_factor is ignored without workers, so only check it when used.
        if self.num_workers > 0:
            must_be_positive["prefetch_factor"] = self.prefetch_factor
        for field_name, value in must_be_positive.items():
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value}")
        if self.num_workers < 0:
            raise ValueError(f"num_workers must be >= 0, got {self.num_workers}")


class PretokenizedShardDataset(Dataset):
    """Map-style dataset over a directory of pre-tokenized ``*.pt`` shards.

    A shard may be a 2-D tensor (one sample per row), a 1-D tensor (a single
    sample), or a list of equal-length token sequences (stacked on load).
    Per-shard sample counts come from ``shard_index.json`` (or
    ``metadata.json``) when that file lists every shard. Otherwise they are
    computed by loading each shard once and written back to
    ``shard_index.json`` on a best-effort basis.

    Only the most recently accessed shard is kept in memory per dataset
    instance (each DataLoader worker holds its own copy). Sequential access is
    therefore cheap, while shuffled access may reload a shard for almost every
    item.

    Args:
        dataset_dir: Directory containing the ``*.pt`` shards.
        max_seq_len: Samples are truncated to ``max_seq_len + 1`` tokens.

    Raises:
        FileNotFoundError: If the directory is missing or contains no shards.
    """

    def __init__(self, dataset_dir: str, max_seq_len: int):
        self.root = Path(dataset_dir)
        if not self.root.exists():
            raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
        self.max_seq_len = max_seq_len
        self.shard_paths = sorted(self.root.glob("*.pt"))
        if not self.shard_paths:
            raise FileNotFoundError(f"No .pt shards found in {dataset_dir}")
        # Single-entry cache of the most recently loaded shard.
        self._cached_shard_path: Optional[Path] = None
        self._cached_shard_tensor: Optional[torch.Tensor] = None
        self.shard_sizes: List[int] = []
        # cumulative_sizes[i] == number of samples in shards 0..i (inclusive).
        self.cumulative_sizes: List[int] = []
        total = 0
        for _shard_path, shard_len in self._load_or_build_shard_index():
            total += shard_len
            self.shard_sizes.append(shard_len)
            self.cumulative_sizes.append(total)

    def _shard_index_path(self) -> Path:
        """Return the location of the shard-length index inside the dataset dir."""
        return self.root / SHARD_INDEX_FILENAME

    def _read_json_file(self, path: Path) -> Dict[str, object]:
        """Parse ``path`` as a JSON object.

        Returns ``{}`` when the file cannot be read, is not valid UTF-8/JSON,
        or its top-level value is not an object.
        """
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            return {}
        return payload if isinstance(payload, dict) else {}

    def _extract_index_entries(self, payload: Dict[str, object]) -> Optional[List[Tuple[Path, int]]]:
        """Map each on-disk shard to its length using an index payload.

        Expects ``{"shards": [{"file": <name>, "length": <int>}, ...]}``.
        Returns ``None`` if the payload is malformed or misses any shard
        currently present in the directory.
        """
        if not isinstance(payload, dict):
            return None
        shard_entries = payload.get("shards")
        if not isinstance(shard_entries, list):
            return None
        lengths_by_name: Dict[str, int] = {}
        for entry in shard_entries:
            if not isinstance(entry, dict):
                return None
            file_name = entry.get("file")
            length = entry.get("length")
            # bool is a subclass of int, and a negative length would corrupt the
            # cumulative offsets used by __getitem__.
            if (
                not isinstance(file_name, str)
                or isinstance(length, bool)
                or not isinstance(length, int)
                or length < 0
            ):
                return None
            lengths_by_name[file_name] = length
        resolved: List[Tuple[Path, int]] = []
        for shard_path in self.shard_paths:
            length = lengths_by_name.get(shard_path.name)
            if length is None:
                return None
            resolved.append((shard_path, length))
        return resolved

    def _load_cached_index(self) -> Optional[List[Tuple[Path, int]]]:
        """Return shard lengths from the first usable index file, or ``None``.

        NOTE(review): entries are matched by file name only, so an index
        becomes stale if a shard's contents change under the same name.
        """
        for candidate in [self._shard_index_path(), self.root / "metadata.json"]:
            if not candidate.exists():
                continue
            resolved = self._extract_index_entries(self._read_json_file(candidate))
            if resolved is not None:
                print(
                    json.dumps(
                        {
                            "event": "dataset_index_loaded",
                            "dataset_dir": str(self.root),
                            "source": candidate.name,
                            "shards": len(resolved),
                            "samples": sum(length for _, length in resolved),
                        }
                    ),
                    flush=True,
                )
                return resolved
        return None

    def _infer_shard_len(self, shard_path: Path) -> int:
        """Load ``shard_path`` and count its samples (rows of a 2-D tensor, list length, else 1)."""
        shard = torch.load(shard_path, map_location="cpu")
        if isinstance(shard, torch.Tensor):
            if shard.ndim == 2:
                return int(shard.size(0))
            return 1
        if isinstance(shard, list):
            return len(shard)
        raise TypeError(f"Unsupported shard format in {shard_path}")

    def _write_cached_index(self, entries: List[Tuple[Path, int]]) -> None:
        """Persist shard lengths to the index file (best effort).

        The JSON is written to a process-specific temporary file and then
        renamed over the index, so readers do not see a half-written file
        (rename is atomic on POSIX within one filesystem). An ``OSError``
        (e.g. a read-only dataset directory) is reported as a
        ``dataset_index_write_failed`` event instead of aborting construction.
        """
        import os

        payload = {
            "shards": [{"file": path.name, "length": length} for path, length in entries],
            "total_samples": sum(length for _, length in entries),
        }
        index_path = self._shard_index_path()
        tmp_path = index_path.with_name(f"{index_path.name}.{os.getpid()}.tmp")
        try:
            tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
            tmp_path.replace(index_path)
        except OSError as err:
            print(
                json.dumps(
                    {
                        "event": "dataset_index_write_failed",
                        "dataset_dir": str(self.root),
                        "error": str(err),
                    }
                ),
                flush=True,
            )
            try:
                tmp_path.unlink()
            except OSError:
                pass

    def _load_or_build_shard_index(self) -> List[Tuple[Path, int]]:
        """Return ``(shard_path, length)`` pairs, building and caching the index if needed."""
        cached = self._load_cached_index()
        if cached is not None:
            return cached
        print(
            json.dumps(
                {
                    "event": "dataset_index_build_start",
                    "dataset_dir": str(self.root),
                    "shards": len(self.shard_paths),
                }
            ),
            flush=True,
        )
        entries: List[Tuple[Path, int]] = []
        running_total = 0
        for shard_idx, shard_path in enumerate(self.shard_paths, start=1):
            shard_len = self._infer_shard_len(shard_path)
            entries.append((shard_path, shard_len))
            running_total += shard_len
            if shard_idx % SHARD_INDEX_PROGRESS_EVERY == 0 or shard_idx == len(self.shard_paths):
                print(
                    json.dumps(
                        {
                            "event": "dataset_index_build_progress",
                            "dataset_dir": str(self.root),
                            "indexed_shards": shard_idx,
                            "total_shards": len(self.shard_paths),
                            "samples": running_total,
                        }
                    ),
                    flush=True,
                )
        self._write_cached_index(entries)
        print(
            json.dumps(
                {
                    "event": "dataset_index_build_done",
                    "dataset_dir": str(self.root),
                    "shards": len(entries),
                    "samples": running_total,
                }
            ),
            flush=True,
        )
        return entries

    def __len__(self) -> int:
        return self.cumulative_sizes[-1]

    def _load_shard(self, shard_idx: int) -> torch.Tensor:
        """Return shard ``shard_idx`` as a tensor with one sample per row, using the cache.

        NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
        the torch version's ``weights_only`` default; only load trusted shards.
        """
        shard_path = self.shard_paths[shard_idx]
        if self._cached_shard_path == shard_path and self._cached_shard_tensor is not None:
            return self._cached_shard_tensor
        shard = torch.load(shard_path, map_location="cpu")
        if isinstance(shard, list):
            # Requires every sequence in the list to have the same length.
            shard = torch.stack([torch.as_tensor(item, dtype=torch.long) for item in shard], dim=0)
        elif isinstance(shard, torch.Tensor):
            if shard.ndim == 1:
                shard = shard.unsqueeze(0)
        else:
            raise TypeError(f"Unsupported shard format in {shard_path}")
        self._cached_shard_path = shard_path
        self._cached_shard_tensor = shard
        return shard

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return sample ``idx`` as a 1-D long tensor of at most ``max_seq_len + 1`` tokens.

        Negative indices count from the end. Samples shorter than two tokens
        are right-padded with zeros to length two.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"Index out of range for dataset of size {length}")
        # bisect_right skips empty shards whose cumulative size equals idx.
        shard_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        shard_start = 0 if shard_idx == 0 else self.cumulative_sizes[shard_idx - 1]
        item_idx = idx - shard_start
        tokens = self._load_shard(shard_idx)[item_idx].to(dtype=torch.long)
        if tokens.numel() < 2:
            padded = torch.zeros(2, dtype=torch.long)
            padded[: tokens.numel()] = tokens
            tokens = padded
        return tokens[: self.max_seq_len + 1]


class SyntheticTokenDataset(Dataset):
    """Fixed-size dataset of uniformly random token sequences, handy for smoke tests.

    Every access draws a fresh ``torch.long`` tensor of ``max_seq_len + 1`` ids
    in ``[0, vocab_size)``. The index itself is ignored, so repeated reads of
    the same index return different data.
    """

    def __init__(self, vocab_size: int, max_seq_len: int, num_samples: int = 128):
        self.vocab_size, self.max_seq_len, self.num_samples = vocab_size, max_seq_len, num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> torch.Tensor:
        sample_len = self.max_seq_len + 1
        return torch.randint(low=0, high=self.vocab_size, size=(sample_len,), dtype=torch.long)


class LayerWiseSGD(torch.optim.Optimizer):
    """Low-memory SGD whose step size is normalised by a per-group running gradient RMS.

    For each param group, the (weight-decayed) gradients are reduced to one
    scalar: the mean, over parameters, of each parameter's mean squared
    gradient. A per-group ``layer_velocity`` accumulates
    ``momentum * velocity + sqrt(that scalar)``. Each parameter is then updated
    by ``-(lr / velocity) * grad``. That single scalar, stored in the param
    group itself, is the only optimizer state, so no per-parameter buffers are
    allocated.

    Args:
        params: Parameters or param-group dicts to optimize.
        lr: Base step size.
        momentum: Decay applied to the accumulated gradient RMS (this is not
            classic per-parameter momentum).
        weight_decay: Coefficient of ``p`` added to the gradient (L2 style).
    """

    def __init__(self, params: Iterable[torch.nn.Parameter], lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0.0):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]
            params_with_grad = [p for p in group["params"] if p.grad is not None]
            if not params_with_grad:
                continue
            device = params_with_grad[0].device
            mean_grad_sq = torch.zeros((), device=device)
            for p in params_with_grad:
                grad = p.grad
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)
                mean_grad_sq = mean_grad_sq + grad.pow(2).mean()
            mean_grad_sq = mean_grad_sq / len(params_with_grad)
            velocity = group.get("layer_velocity")
            if velocity is None:
                velocity = torch.zeros((), device=device)
            else:
                # A velocity restored from a checkpoint may live on another device.
                velocity = velocity.to(device)
            velocity = (momentum * velocity) + mean_grad_sq.sqrt()
            group["layer_velocity"] = velocity
            # Convert to a Python float once per group: ``alpha`` is a Scalar, and
            # passing the 0-dim tensor would force a device sync per parameter.
            step_size = float(lr / velocity.clamp(min=1e-8))
            for p in params_with_grad:
                # Recompute the decayed gradient rather than keeping copies from the
                # first pass, to avoid holding a second set of gradient tensors.
                grad = p.grad
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)
                p.add_(grad, alpha=-step_size)
        return loss


def _build_adafactor(params: Iterable[torch.nn.Parameter], cfg: TrainStackConfig):
    if importlib.util.find_spec("transformers") is None:
        return torch.optim.AdamW(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    transformers = __import__("transformers")
    return transformers.Adafactor(
        params,
        lr=cfg.learning_rate,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
        weight_decay=cfg.weight_decay,
    )


def _build_adam8bit(params: Iterable[torch.nn.Parameter], cfg: TrainStackConfig):
    if importlib.util.find_spec("bitsandbytes") is None:
        return torch.optim.AdamW(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    bnb = __import__("bitsandbytes")
    return bnb.optim.Adam8bit(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)


def build_optimizer(model: torch.nn.Module, cfg: TrainStackConfig) -> torch.optim.Optimizer:
    """Instantiate the optimizer named by ``cfg.optimizer_name`` (case-insensitive).

    Recognised names:
      * ``"auto"``: resolved via ``choose_optimizer_backend(prefer_low_memory=True)``,
        which presumably returns one of the names below.
      * ``"adamw_fused"`` / ``"fused_adamw"``: fused AdamW on CUDA when supported,
        otherwise plain AdamW.
      * ``"adafactor"``: transformers Adafactor (AdamW fallback).
      * ``"adam8bit"`` / ``"adam_8bit"`` / ``"8bit-adam"``: bitsandbytes Adam8bit
        (AdamW fallback).
      * ``"layerwisesgd"`` / ``"lowmemsgd"`` / ``"sgd"``: ``LayerWiseSGD`` with momentum 0.9.
      * anything else: ``torch.optim.AdamW``.

    All variants use ``cfg.learning_rate`` and ``cfg.weight_decay``.
    """
    name = cfg.optimizer_name.lower()
    if name == "auto":
        name = choose_optimizer_backend(prefer_low_memory=True)
    if name in {"adamw_fused", "fused_adamw"}:
        if torch.cuda.is_available():
            try:
                return torch.optim.AdamW(
                    model.parameters(),
                    lr=cfg.learning_rate,
                    weight_decay=cfg.weight_decay,
                    fused=True,
                )
            except (TypeError, RuntimeError):
                # TypeError: this torch build has no ``fused`` argument.
                # RuntimeError: some torch versions reject fused=True at construction
                # when params are not on a fused-capable device/dtype (e.g. the model
                # has not been moved to CUDA yet). Fall back to plain AdamW.
                pass
        return torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    if name == "adafactor":
        return _build_adafactor(model.parameters(), cfg)
    if name in {"adam8bit", "adam_8bit", "8bit-adam"}:
        return _build_adam8bit(model.parameters(), cfg)
    if name in {"layerwisesgd", "lowmemsgd", "sgd"}:
        return LayerWiseSGD(model.parameters(), lr=cfg.learning_rate, momentum=0.9, weight_decay=cfg.weight_decay)
    return torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)


def collate_token_batch(batch: List[torch.Tensor], fixed_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
    """Collate 1-D token tensors into shifted next-token-prediction pairs.

    A sample ``tokens`` contributes ``input_ids = tokens[:-1]`` and
    ``target_ids = tokens[1:]``.

    Args:
        batch: 1-D token tensors; lengths may differ.
        fixed_length: Optional cap on tokens per sample. If every sample has at
            least this many tokens, samples are truncated and stacked with no
            padding. Otherwise samples are truncated to it and right-padded.

    Returns:
        ``{"input_ids": ..., "target_ids": ...}``, each of shape
        ``(len(batch), T - 1)``. ``T`` is ``fixed_length`` on the fast path, or
        the longest (capped) sample length otherwise. Padded positions hold
        ``0`` in ``input_ids`` and ``-100`` in ``target_ids`` (presumably the
        loss's ignore index).
    """
    if fixed_length is not None:
        if all(item.numel() >= fixed_length for item in batch):
            stacked = torch.stack([item[:fixed_length] for item in batch], dim=0)
            return {"input_ids": stacked[:, :-1], "target_ids": stacked[:, 1:]}
        # Honour the cap on the padded path too, so no sample exceeds fixed_length.
        batch = [item[:fixed_length] for item in batch]
    max_len = max(item.numel() for item in batch)
    width = max(max_len - 1, 0)
    inputs = torch.zeros((len(batch), width), dtype=torch.long)
    targets = torch.full((len(batch), width), -100, dtype=torch.long)
    for row, item in enumerate(batch):
        n_pairs = item.numel() - 1
        if n_pairs <= 0:
            # Empty or single-token samples yield no (input, target) pair.
            continue
        inputs[row, :n_pairs] = item[:-1]
        targets[row, :n_pairs] = item[1:]
    return {"input_ids": inputs, "target_ids": targets}


def build_dataset(dataset_dir: str, vocab_size: int, max_seq_len: int, synthetic_samples: int = 128) -> Dataset:
    """Return the shard dataset under ``dataset_dir``, or a synthetic one when it is empty.

    ``vocab_size`` and ``synthetic_samples`` only matter for the synthetic
    fallback. ``max_seq_len`` is forwarded to whichever dataset is built.
    """
    if not dataset_dir:
        return SyntheticTokenDataset(vocab_size=vocab_size, max_seq_len=max_seq_len, num_samples=synthetic_samples)
    return PretokenizedShardDataset(dataset_dir, max_seq_len=max_seq_len)


def build_dataloader(dataset: Dataset, cfg: TrainStackConfig, shuffle: bool = True) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader configured from ``cfg``.

    Batches are collated by ``collate_token_batch`` with a fixed length of
    ``cfg.max_seq_len + 1``. ``persistent_workers`` and ``prefetch_factor``
    are only applied when worker processes are used (``num_workers > 0``).
    """
    use_workers = cfg.num_workers > 0
    collate = functools.partial(collate_token_batch, fixed_length=cfg.max_seq_len + 1)
    loader_kwargs = {
        "batch_size": cfg.batch_size,
        "shuffle": shuffle,
        "num_workers": cfg.num_workers,
        "pin_memory": cfg.pin_memory,
        "persistent_workers": cfg.persistent_workers and use_workers,
        "collate_fn": collate,
    }
    if use_workers:
        loader_kwargs["prefetch_factor"] = cfg.prefetch_factor
    return DataLoader(dataset, **loader_kwargs)


def move_batch_to_device(batch: Dict[str, torch.Tensor], device: torch.device, non_blocking: bool = True) -> Dict[str, torch.Tensor]:
    """Return a new dict with every tensor in ``batch`` moved to ``device`` (same keys, same order)."""
    moved: Dict[str, torch.Tensor] = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device, non_blocking=non_blocking)
    return moved



def train_demo_steps(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    dataloader: DataLoader,
    device: torch.device,
    steps: int = 2,
    use_bf16: bool = True,
) -> Tuple[float, int]:
    """Run up to ``steps`` optimizer steps and report the mean loss and token count.

    ``model`` must provide ``training_loss(input_ids, target_ids)`` returning a
    scalar loss tensor. Each step does the following:
      * zero the gradients;
      * run the forward pass, under bf16 autocast when ``use_bf16`` is set and
        ``device`` is CUDA;
      * back-propagate;
      * clip the global gradient norm to 1.0;
      * step the optimizer.

    Args:
        model: Module exposing ``training_loss``.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Yields dicts with ``"input_ids"`` and ``"target_ids"`` tensors.
        device: Device the batches are moved to.
        steps: Maximum number of steps; fewer run if the dataloader is shorter.
        use_bf16: Enable bf16 autocast (CUDA only).

    Returns:
        ``(mean_loss, total_tokens)``. ``mean_loss`` is averaged over the steps
        actually run (``0.0`` if none ran). ``total_tokens`` counts target
        positions not equal to ``-100``.
    """
    import itertools

    model.train()
    total_loss = 0.0
    total_tokens = 0
    completed_steps = 0
    autocast_enabled = use_bf16 and device.type == "cuda"
    # islice stops without requesting an extra batch from the dataloader.
    for batch in itertools.islice(dataloader, max(0, steps)):
        batch = move_batch_to_device(batch, device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
            loss = model.training_loss(batch["input_ids"], batch["target_ids"])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += float(loss.detach().item())
        total_tokens += int((batch["target_ids"] != -100).sum().item())
        completed_steps += 1
    # Divide by the steps actually executed, not the requested count.
    mean_loss = total_loss / max(1, completed_steps)
    return mean_loss, total_tokens