import math
import os
import signal
import sys
import time
from typing import List, Optional, Tuple, Union

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import infinity.utils.dist as dist
from infinity.utils import misc

class NullCtx:
    """No-op context manager, used in place of torch.autocast when mixed precision is disabled."""
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


def handle_timeout(signum, frame):
    """SIGALRM handler: turn an alarm signal into a TimeoutError."""
    raise TimeoutError('took too long')
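# Illustrative sketch (not part of the original training code): `handle_timeout` is meant to be
# registered as a SIGALRM handler so that a blocking call can be bounded. The helper below is a
# hypothetical usage example and works on POSIX systems only.
def _run_with_timeout_example(fn, seconds: int = 60):
    old_handler = signal.signal(signal.SIGALRM, handle_timeout)
    signal.alarm(seconds)  # deliver SIGALRM after `seconds`
    try:
        return fn()
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)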
def per_param_clip_grad_norm_(parameters, thresh: float, stable=False, fp=None) -> Tuple[float, float]:
    """Clip every parameter's gradient to norm `thresh` independently (per-parameter, not global).

    With `stable=True`, a gradient that would need a clip coefficient below 0.2 is zeroed out
    instead of clipped. Returns (log10 of the smallest skipped clip coefficient, or 0 if nothing
    was skipped; the largest per-parameter gradient norm observed).
    """
    skipped, max_grad = [], 0
    for pi, p in enumerate(parameters):
        if p.grad is not None:
            g = p.grad.data.norm(2).item() + 1e-7
            max_grad = max(max_grad, g)
            clip_coef = thresh / g
            if clip_coef < 1:
                if stable and clip_coef < 0.2:
                    skipped.append(clip_coef)
                    p.grad.data.mul_(0)
                else:
                    p.grad.data.mul_(clip_coef)

    return 0 if len(skipped) == 0 else math.log10(max(min(skipped), 1e-7)), max_grad
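# Minimal usage sketch for per-parameter clipping (the threshold and the `model` argument are
# hypothetical, not taken from the original code): every gradient is clipped to norm 1.0
# independently and the returned statistics can be logged.
def _per_param_clip_example(model: torch.nn.Module):
    log10_min_skipped, max_grad = per_param_clip_grad_norm_(model.parameters(), thresh=1.0, stable=True)
    print(f'max per-param grad norm: {max_grad:.4f}, log10(min skipped clip coef): {log10_min_skipped}')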
class AmpOptimizer:
    def __init__(
        self,
        model_name_3letters: str, mixed_precision: int,
        optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP],
        r_accu: float, grad_clip: float, zero: int,
    ):
        # mixed_precision: <= 0 disables AMP, 2 selects bf16, any other positive value selects fp16
        # (a value > 128 is additionally used as the loss-scale cap)
        self.enable_amp = mixed_precision > 0
        self.zero = zero
        if self.enable_amp:
            self.using_fp16_rather_bf16 = mixed_precision != 2
            self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768)

            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0)
            # a loss scaler is only needed for fp16; bf16 has enough dynamic range
            if self.using_fp16_rather_bf16:
                self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
            else:
                self.scaler = None
        else:
            self.using_fp16_rather_bf16 = True
            self.amp_ctx = NullCtx()
            self.scaler = None

        # sanity check: all ranks must agree on whether AMP / fp16 is enabled
        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.enable_amp)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}'

        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.using_fp16_rather_bf16)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}'

        self.model_name_3letters = model_name_3letters
        self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp
        self.r_accu = r_accu   # gradient-accumulation ratio applied to every loss

        self.paras = self.names = ...

        # a grad_clip above 100 encodes per-parameter clipping with threshold grad_clip % 100
        self.grad_clip, self.grad_clip_we = grad_clip, 0
        if self.grad_clip > 100:
            self.grad_clip %= 100
            self.per_param = True
        else:
            self.per_param = False

        # optimizers that expose `global_grad_norm` clip internally ('late'); otherwise clip manually before stepping ('early')
        self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
        self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')

        self.fp = None
        self.last_orig_norm: torch.Tensor = torch.tensor(0.1)

    @torch.no_grad()
    def log_param(self, ep: int):
        # NOTE: `get_param_for_log` is not defined in this file; it is assumed to be provided elsewhere in the codebase
        if self.zero == 0:
            for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items():
                values: List[float]
                if len(values) == 1:
                    values.append(values[0])
                else:
                    ...

    def backward_clip_step(
        self, ep: int, it: int, g_it: int, stepping: bool, logging_params: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False,
    ) -> Tuple[torch.Tensor, Optional[float]]:
        # scale the loss by the gradient-accumulation ratio, then backprop (through the loss scaler when using fp16)
        loss = loss.mul(self.r_accu)
        orig_norm = scaler_sc = None

        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)

        if stepping:
            # un-scale fp16 gradients before clipping so the threshold is in real units
            if self.scaler is not None: self.scaler.unscale_(self.optimizer)

            skipped, orig_norm = 0, self.last_orig_norm

            if self.fp is not None:
                if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0)
                self.fp.write(f'<ep{ep} it{it} {g_it}>\n'); self.fp.flush()
            if self.early_clipping:
                c = self.grad_clip * clip_decay_ratio
                if self.zero:
                    # under ZeRO/FSDP the gradients are sharded, so clipping must go through the wrapper
                    orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c)
                else:
                    orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c)

            if self.scaler is not None:
                self.scaler: torch.cuda.amp.GradScaler
                if self.zero:
                    # with ZeRO/FSDP sharding, every rank must agree on whether any shard hit an inf/nan
                    for optimizer_state in self.scaler._per_optimizer_states.values():
                        for t in optimizer_state['found_inf_per_device'].values():
                            dist.allreduce(t)

                self.scaler.step(self.optimizer)
                scaler_sc: Optional[float] = self.scaler.get_scale()
                if scaler_sc > self.max_sc:
                    # cap the loss scale so it cannot grow without bound
                    self.scaler.update(new_scale=self.max_sc)
                else:
                    self.scaler.update()
                try:
                    # report the scale as log2 for readable logging
                    scaler_sc = float(math.log2(scaler_sc))
                except Exception as e:
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    time.sleep(1)
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    raise e
            else:
                self.optimizer.step()

            if self.late_clipping:
                orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm
            self.last_orig_norm = orig_norm

        return orig_norm, scaler_sc

    def state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict()
        } if self.scaler is None else {
            'scaler': self.scaler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

    def load_state_dict(self, state, strict=True):
        if self.scaler is not None:
            try: self.scaler.load_state_dict(state['scaler'])
            except Exception as e: print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
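# Illustrative training-step sketch (an assumption, not code from the original repo): the model,
# data loader, loss and accumulation count below are hypothetical, and the `dist` helpers are
# assumed to be initialized already. It shows the intended pattern: run the forward pass under
# `amp_ctx`, call backward_clip_step() on every micro-batch, and let `stepping` decide when
# gradients are unscaled, clipped and applied.
def _amp_optimizer_usage_example(model: torch.nn.Module, loader, accu: int = 4):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    amp_opt = AmpOptimizer(
        model_name_3letters='inf', mixed_precision=1, optimizer=opt,
        model_maybe_fsdp=model, r_accu=1.0 / accu, grad_clip=2.0, zero=0,
    )
    for g_it, (x, y) in enumerate(loader):
        with amp_opt.amp_ctx:
            loss = torch.nn.functional.mse_loss(model(x), y)
        stepping = (g_it + 1) % accu == 0
        grad_norm, scale_log2 = amp_opt.backward_clip_step(
            ep=0, it=g_it, g_it=g_it, stepping=stepping, logging_params=False, loss=loss,
        )
        if stepping:
            opt.zero_grad(set_to_none=True)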
class AmpOptimizerVAE:
    def __init__(
        self,
        model_name_3letters: str, mixed_precision: int,
        optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP],
        r_accu: float, grad_clip: float, zero: int,
        vae_local,
    ):
        # mixed_precision: <= 0 disables AMP, 2 selects bf16, any other positive value selects fp16
        # (a value > 128 is additionally used as the loss-scale cap)
        self.enable_amp = mixed_precision > 0
        self.zero = zero
        if self.enable_amp:
            self.using_fp16_rather_bf16 = mixed_precision != 2
            self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768)

            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0)
            # a loss scaler is only needed for fp16; bf16 has enough dynamic range
            if self.using_fp16_rather_bf16:
                self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
            else:
                self.scaler = None
        else:
            self.using_fp16_rather_bf16 = True
            self.amp_ctx = NullCtx()
            self.scaler = None

        # sanity check: all ranks must agree on whether AMP / fp16 is enabled
        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.enable_amp)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}'

        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.using_fp16_rather_bf16)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}'

        self.model_name_3letters = model_name_3letters
        self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp
        self.r_accu = r_accu   # gradient-accumulation ratio applied to every loss

        self.paras = self.names = ...

        # a grad_clip above 100 encodes per-parameter clipping with threshold grad_clip % 100
        self.grad_clip, self.grad_clip_we = grad_clip, 0
        if self.grad_clip > 100:
            self.grad_clip %= 100
            self.per_param = True
        else:
            self.per_param = False

        # optimizers that expose `global_grad_norm` clip internally ('late'); otherwise clip manually before stepping ('early')
        self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
        self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')

        self.fp = None
        self.last_orig_norm: torch.Tensor = torch.tensor(0.1)

        # extra handle to the local VAE so its (LoRA) gradients can be clipped alongside the main model's
        self.vae_local = vae_local

    @torch.no_grad()
    def log_param(self, ep: int):
        # NOTE: `get_param_for_log` is not defined in this file; it is assumed to be provided elsewhere in the codebase
        if self.zero == 0:
            for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items():
                values: List[float]
                if len(values) == 1:
                    values.append(values[0])
                else:
                    ...

    def backward_clip_step(
        self, ep: int, it: int, g_it: int, stepping: bool, logging_params: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False,
    ) -> Tuple[torch.Tensor, Optional[float]]:
        # scale the loss by the gradient-accumulation ratio, then backprop (through the loss scaler when using fp16)
        loss = loss.mul(self.r_accu)
        orig_norm = scaler_sc = None

        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)

        # LoRA-gradient debug prints, disabled here: `self.gpt_ddp` is not an attribute of this
        # class (only `self.model_maybe_fsdp` is), so the last two lines would raise as written.
        # print(f"vae_encoder {torch.sum(self.vae_local.encoder.down[0].block[0].conv1.conv.lora_down.weight.grad)}")
        # print(f"infinity {self.gpt_ddp.block_chunks[0].module.module[0].ca.mat_kv.lora_up.weight.requires_grad}")
        # print(f"infinity {torch.sum(self.gpt_ddp.block_chunks[0].module.module[0].ca.mat_kv.lora_up.weight.grad)}")

        if stepping:
            # un-scale fp16 gradients before clipping so the threshold is in real units
            if self.scaler is not None: self.scaler.unscale_(self.optimizer)

            skipped, orig_norm = 0, self.last_orig_norm

            if self.fp is not None:
                if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0)
                self.fp.write(f'<ep{ep} it{it} {g_it}>\n'); self.fp.flush()
            if self.early_clipping:
                c = self.grad_clip * clip_decay_ratio
                if self.zero:
                    # under ZeRO/FSDP the gradients are sharded, so clipping must go through the wrapper
                    orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c)
                else:
                    orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c)

                # clip the VAE's gradients with the same threshold
                orig_norm_vae = torch.nn.utils.clip_grad_norm_(self.vae_local.parameters(), c)
                print(f'orig_norm {orig_norm} orig_norm_vae {orig_norm_vae}')

            if self.scaler is not None:
                self.scaler: torch.cuda.amp.GradScaler
                if self.zero:
                    # with ZeRO/FSDP sharding, every rank must agree on whether any shard hit an inf/nan
                    for optimizer_state in self.scaler._per_optimizer_states.values():
                        for t in optimizer_state['found_inf_per_device'].values():
                            dist.allreduce(t)

                self.scaler.step(self.optimizer)
                scaler_sc: Optional[float] = self.scaler.get_scale()
                if scaler_sc > self.max_sc:
                    # cap the loss scale so it cannot grow without bound
                    self.scaler.update(new_scale=self.max_sc)
                else:
                    self.scaler.update()
                try:
                    # report the scale as log2 for readable logging
                    scaler_sc = float(math.log2(scaler_sc))
                except Exception as e:
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    time.sleep(1)
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    raise e
            else:
                self.optimizer.step()

            if self.late_clipping:
                orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm
            self.last_orig_norm = orig_norm

        return orig_norm, scaler_sc

    def state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict()
        } if self.scaler is None else {
            'scaler': self.scaler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

    def load_state_dict(self, state, strict=True):
        if self.scaler is not None:
            try: self.scaler.load_state_dict(state['scaler'])
            except Exception as e: print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
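# Hypothetical construction sketch for the VAE-aware variant (optimizer choice and learning rate
# are assumptions): it mirrors AmpOptimizer but also keeps a `vae_local` handle, whose parameters
# are clipped (and their norm printed) alongside the main model's inside backward_clip_step.
def _amp_optimizer_vae_example(model: torch.nn.Module, vae: torch.nn.Module):
    trainable = [p for p in list(model.parameters()) + list(vae.parameters()) if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=1e-4)
    return AmpOptimizerVAE(
        model_name_3letters='inf', mixed_precision=2, optimizer=opt,
        model_maybe_fsdp=model, r_accu=1.0, grad_clip=2.0, zero=0, vae_local=vae,
    )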