| |
| |
|
|
| import json |
| import logging |
| from collections import namedtuple |
| from dataclasses import asdict, dataclass |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| import wandb |
|
|
| from core.distributed import get_is_master |
|
|
| Scalar = Union[int, float] |
|
|
| logger = logging.getLogger() |
|
|
|
|
@dataclass
class WandbArgs:
    """Keyword arguments forwarded verbatim to ``wandb.init``.

    Every field defaults to ``None`` so that ``asdict(...)`` expansion (see
    ``MetricLogger.open``) passes each setting through unchanged and lets
    wandb apply its own defaults. Field names intentionally mirror the
    ``wandb.init`` signature, including ``dir`` and ``id`` (which shadow
    builtins) — do not rename them.
    """

    job_type: Optional[str] = None
    dir: Optional[str] = None
    project: Optional[str] = None
    entity: Optional[str] = None
    tags: Optional[List] = None
    group: Optional[str] = None
    name: Optional[str] = None
    notes: Optional[str] = None
    config_exclude_keys: Optional[List[str]] = None
    config_include_keys: Optional[List[str]] = None
    anonymous: Optional[str] = None
    mode: Optional[str] = None
    allow_val_change: Optional[bool] = None
    resume: Optional[Union[bool, str]] = None
    force: Optional[bool] = None
    tensorboard: Optional[bool] = None
    sync_tensorboard: Optional[bool] = None
    monitor_gym: Optional[bool] = None
    save_code: Optional[bool] = None
    id: Optional[str] = None
    fork_from: Optional[str] = None
    resume_from: Optional[str] = None
|
|
|
|
@dataclass
class LoggingArgs:
    """Logging configuration consumed by the training loop."""

    # How often to log metrics (presumably in optimizer steps — confirm
    # against the trainer that reads this).
    freq: int = 10
    # Python logging level name (e.g. "INFO", "DEBUG").
    level: str = "INFO"
    # NOTE(review): semantics not visible here — presumably an accumulation
    # logging frequency; verify against the caller.
    acc_freq: Optional[int] = None

    # When None, Weights & Biases logging is disabled (see MetricLogger.open).
    wandb: Optional[WandbArgs] = None
|
|
|
|
class MetricLogger:
    """Append-only JSONL metric writer with optional Weights & Biases mirroring.

    Each :meth:`log` call appends one JSON object (augmented with a UTC
    ``created_at`` timestamp) to the file at ``outdir``. When
    ``args.logging.wandb`` is configured and this process is the master rank,
    metrics are also forwarded to wandb. Usable as a context manager.
    """

    def __init__(self, outdir: Path, args: Optional[Any] = None):
        self.outdir = outdir  # path of the JSONL file to append to
        self.jsonl_writer = None  # lazily opened in open()
        self.args = args

    def open(self):
        """Open the JSONL file and, on the master rank, initialize wandb."""
        if self.jsonl_writer is None:
            self.jsonl_writer = open(self.outdir, "a", encoding="utf-8")
        if (
            self.args is not None
            and self.args.logging.wandb is not None
            and get_is_master()
        ):
            # wandb.init registers the run globally (wandb.run, checked in
            # log()), so the return value need not be kept. The previous
            # version bound it to an unused local.
            wandb.init(
                config=asdict(self.args),
                **asdict(self.args.logging.wandb),
            )

    def log(self, metrics: Dict[str, Any]):
        """Write one metrics record.

        Requires a ``"global_step"`` key when wandb logging is active.
        The caller's dict is not mutated.
        """
        if (
            self.args is not None
            and self.args.logging.wandb is not None
            and (wandb.run is not None)
        ):
            wandb.log(metrics, step=metrics["global_step"])

        # Copy before adding the timestamp so the caller's dict is untouched
        # (the previous version mutated the argument in place).
        record = dict(metrics)
        record["created_at"] = datetime.now(timezone.utc).isoformat()
        print(json.dumps(record), file=self.jsonl_writer, flush=True)

    def close(self):
        """Close the JSONL file. Idempotent."""
        if self.jsonl_writer is not None:
            self.jsonl_writer.close()
            self.jsonl_writer = None

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # getattr guard: __del__ may run on a partially-constructed instance
        # (e.g. if __init__ raised before setting jsonl_writer).
        if getattr(self, "jsonl_writer", None) is not None:
            self.close()
|
|
|
|
# Snapshot of peak GPU memory statistics produced by
# GPUMemoryMonitor.get_peak_stats. "*_gib" fields are in GiB, "*_pct"
# fields are percentages of total device memory; retry/OOM counts come
# from torch.cuda.memory_stats and power_draw from torch.cuda.power_draw.
GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
        "power_draw",
    ],
)
|
|
|
|
class GPUMemoryMonitor:
    """Monitor peak CUDA memory usage (and power draw) for one device.

    Peak counters are reset on construction so subsequent readings reflect
    only work done after the monitor was created.
    """

    def __init__(self, device: str = "cuda:0"):
        self.device = torch.device(device)
        self.device_name = torch.cuda.get_device_name(self.device)
        # Use the index of the *requested* device; fall back to the current
        # device only when the string carries no index. (The previous version
        # always used current_device(), which is wrong when `device` names a
        # GPU other than the current one.)
        self.device_index = (
            self.device.index
            if self.device.index is not None
            else torch.cuda.current_device()
        )
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        # Start from a clean slate so peak stats reflect this run only.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def _to_gib(self, memory_in_bytes):
        """Convert a byte count to GiB (2**30 bytes)."""
        _gib_in_bytes = 1024 * 1024 * 1024
        return memory_in_bytes / _gib_in_bytes

    def _to_pct(self, memory):
        """Express `memory` (bytes) as a percentage of total device memory."""
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Return a GPUMemStats snapshot; warn on allocation retries / OOMs."""
        cuda_info = torch.cuda.memory_stats(self.device)

        max_active = cuda_info["active_bytes.all.peak"]
        max_reserved = cuda_info["reserved_bytes.all.peak"]
        num_retries = cuda_info["num_alloc_retries"]
        num_ooms = cuda_info["num_ooms"]
        power_draw = torch.cuda.power_draw()

        if num_retries > 0:
            logger.warning("%d CUDA memory allocation retries.", num_retries)
        if num_ooms > 0:
            logger.warning("%d CUDA OOM errors thrown.", num_ooms)

        return GPUMemStats(
            self._to_gib(max_active),
            self._to_pct(max_active),
            self._to_gib(max_reserved),
            self._to_pct(max_reserved),
            num_retries,
            num_ooms,
            power_draw,
        )

    def reset_peak_stats(self):
        """Reset CUDA's peak and accumulated memory counters."""
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()

    def __str__(self):
        # Round to 2 decimals for readable log lines (raw floats print with
        # full precision).
        stats = self.get_peak_stats()
        return (
            f"{self.device_name} ({self.device_index}): "
            f"{self.device_capacity_gib:.2f} GiB capacity, "
            f"{stats.max_reserved_gib:.2f} GiB peak, "
            f"{stats.max_reserved_pct:.2f}% peak"
        )
|
|
|
|
def upload_train_to_wandb(
    ckpt_dir, project="perception", entity="codegen-team", train=True, eval=True
):
    """Replay metrics JSONL files from a checkpoint dir as wandb runs.

    Loads ``config.yaml`` from ``ckpt_dir`` for the run config/name, then
    uploads ``metrics.jsonl`` (when ``train``) and ``metrics.eval.jsonl``
    (when ``eval``) as separate wandb runs, logging each line at its
    ``global_step``.

    NOTE: the ``eval`` parameter shadows the builtin; kept as-is for
    backward compatibility with existing callers.
    """
    import json
    from pathlib import Path

    import wandb
    from omegaconf import OmegaConf

    cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
    cfg = OmegaConf.to_container(cfg)

    def _replay(jsonl_name, transform):
        # One wandb run per metrics file; `transform` maps a parsed record
        # to the dict actually logged.
        wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
        with open(Path(ckpt_dir) / jsonl_name) as f:
            for line in f:
                m = json.loads(line)
                wandb.log(transform(m), step=m["global_step"])
        wandb.finish()

    if train:
        _replay("metrics.jsonl", lambda m: m)

    if eval:
        # Only namespaced metrics (names containing "/") are uploaded,
        # re-rooted under "evals/" with "/" replaced by "." in the name.
        _replay(
            "metrics.eval.jsonl",
            lambda m: {
                f"evals/{name.replace('/','.')}": value
                for name, value in m.items()
                if "/" in name
            },
        )
|
|
|
|
def get_num_params(model: nn.Module) -> int:
    """Return the total number of parameters in `model`.

    Counts all parameters regardless of ``requires_grad`` (shared parameters
    are counted once, per ``model.parameters()`` deduplication). The previous
    docstring referenced a nonexistent ``only_trainable`` argument and an
    intermediate name->count dict was built only to be summed.
    """
    return sum(p.numel() for p in model.parameters())
|
|
|
|
def log_model_params(model):
    """Log counts and names of frozen vs. trainable parameters.

    Accepts a single module or an ``nn.ModuleList`` (tallied per entry).
    Parameter name lists go to DEBUG; totals go to INFO.
    """
    frozen_params, unfrozen_params = [], []
    num_frozen_params, num_unfrozen_params = 0, 0

    def _fn(_model):
        # Partition parameter names by requires_grad and tally element counts.
        num_frozen_params, num_unfrozen_params = 0, 0
        for pname, pval in _model.named_parameters():
            if pval.requires_grad:
                unfrozen_params.append(pname)
                num_unfrozen_params += pval.numel()
            else:
                frozen_params.append(pname)
                num_frozen_params += pval.numel()
        return num_frozen_params, num_unfrozen_params

    if isinstance(model, torch.nn.ModuleList):
        for m in model:
            _num_frozen_params, _num_unfrozen_params = _fn(m)
            num_frozen_params += _num_frozen_params
            num_unfrozen_params += _num_unfrozen_params
    else:
        num_frozen_params, num_unfrozen_params = _fn(model)

    # Plain string: the previous version used an f-string with no
    # placeholders. Lazy %-args below avoid building (potentially huge)
    # parameter-name strings when DEBUG logging is disabled.
    logger.info("Logging Trainable Parameters after first step.")
    logger.debug("Frozen params: %s", frozen_params)
    logger.debug("Trainable params: %s", unfrozen_params)
    logger.info(
        f"Params total: {num_frozen_params + num_unfrozen_params:,}, "
        f"Learnable: {num_unfrozen_params:,}, "
        f"Frozen: {num_frozen_params:,}."
    )
|
|