| import torch |
| import datetime |
| import time |
| import torch.distributed as dist |
| import yaml |
| import os |
|
|
class MetricLogger:
    """Collect named training metrics and print periodic progress lines.

    Each metric name maps to a ``SmoothedValue`` kept in ``self.meters``.

    Args:
        delimiter: String placed between the fields of a formatted line.
    """

    def __init__(self, delimiter="\t"):
        self.meters = {}
        self.delimiter = delimiter

    @staticmethod
    def _is_main_process():
        """Return True when this process should print progress lines.

        Without an initialised ``torch.distributed`` process group there is
        only one process, so it always logs; otherwise only rank 0 does.
        """
        if not dist.is_available() or not dist.is_initialized():
            return True
        return dist.get_rank() == 0

    def update(self, **kwargs):
        """Record one new value per keyword argument.

        Tensor values are converted with ``.item()``. A meter is created the
        first time a name is seen.
        """
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            if name not in self.meters:
                self.meters[name] = SmoothedValue()
            self.meters[name].update(value)

    def __str__(self):
        """Return all meters as ``name: meter`` joined by the delimiter."""
        return self.delimiter.join(
            f"{name}: {meter}" for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Ask every meter to synchronise its totals across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A progress line (index, ETA, current meters, iteration time, data
        loading time) is printed on the main process whenever the index is
        a multiple of ``print_freq`` and on the last item. Once the iterable
        is exhausted a total-time summary is printed.

        Args:
            iterable: Items to iterate over; must support ``len()``.
            print_freq: Number of iterations between progress lines.
            header: Optional prefix for every printed line.

        Yields:
            Each item of ``iterable``, unchanged.
        """
        if not header:
            header = ''
        total = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the running index to the width of the total count.
        space_fmt = ':' + str(len(str(total))) + 'd'
        log_msg = self.delimiter.join([
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ])
        for i, obj in enumerate(iterable):
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time, including the consumer's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                # Do not require CUDA or an initialised process group just
                # to print a progress line.
                if self._is_main_process():
                    print(log_msg.format(
                        i, total, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids ZeroDivisionError when the iterable is empty.
        per_item = total_time / max(total, 1)
        print(f'{header} Total time: {total_time_str} ({per_item:.4f} s / it)')
|
|
|
|
class SmoothedValue:
    """Track a series of values and expose windowed and global statistics.

    The most recent ``window_size`` values are kept in ``self.deque`` and
    feed ``median``, ``avg``, ``max`` and ``value``. ``count`` and ``total``
    accumulate over every update and feed ``global_avg``.

    Args:
        window_size: Maximum number of recent values retained.
        fmt: Format string for ``str()``; may reference ``median``, ``avg``,
            ``global_avg``, ``max`` and ``value``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = []
        self.total = 0.0
        self.count = 0
        self.fmt = fmt
        self.window_size = window_size

    def update(self, value, n=1):
        """Record ``value``, counted ``n`` times in the global totals."""
        self.deque.append(value)
        if len(self.deque) > self.window_size:
            # Drop the oldest value so the window never exceeds its size.
            self.deque.pop(0)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all processes.

        Does nothing when ``torch.distributed`` is unavailable or not
        initialised. Only the global totals are reduced; the window of
        recent values stays local to each process.
        """
        if not dist.is_available() or not dist.is_initialized():
            return
        # Use the GPU when there is one, but fall back to the CPU instead of
        # failing on machines without CUDA.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        t = torch.tensor(
            [self.count, self.total], dtype=torch.float64, device=device
        )
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the window (mean of the two middle values if even)."""
        d = sorted(self.deque)
        n = len(d)
        if n == 0:
            return 0
        if n % 2 == 0:
            return (d[n // 2 - 1] + d[n // 2]) / 2
        return d[n // 2]

    @property
    def avg(self):
        """Mean of the values currently in the window, or 0 if empty."""
        if not self.deque:
            return 0
        return sum(self.deque) / len(self.deque)

    @property
    def global_avg(self):
        """``total / count`` over every update, or 0 before any update."""
        if self.count == 0:
            return 0
        return self.total / self.count

    def __str__(self):
        """Render the statistics using ``self.fmt``."""
        has_values = len(self.deque) > 0
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=max(self.deque) if has_values else 0,
            value=self.deque[-1] if has_values else 0
        )
| |
|
|
|
|
def load_config(config_path):
    """Load configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    # Decode as UTF-8 explicitly so non-ASCII text in the file is read the
    # same way regardless of the platform's default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
|
|
|
|
def log_to_file(log_file, message):
    """Append a timestamped message to a log file.

    Args:
        log_file: Path of the log file, or ``None`` to do nothing.
        message: Text to write after the ``[YYYY-MM-DD HH:MM:SS]`` prefix.
    """
    if log_file is None:
        return
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Write UTF-8 explicitly so non-ASCII messages do not depend on the
    # platform's default encoding. Leaving the ``with`` block closes the
    # file, which also flushes it.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"[{timestamp}] {message}\n")
|
|
|
|
def count_parameters(model, verbose=True):
    """Count the trainable parameters of ``model``.

    If ``model`` has a ``module`` attribute, that inner object is counted
    instead. With ``verbose`` set, a table is printed listing whichever of
    ``patch_embed``, ``blocks``, ``encoder_norm`` and ``head`` exist on the
    model, followed by the encoder subtotal and the overall total.

    Args:
        model: Object exposing ``parameters()``.
        verbose: Whether to print the breakdown table.

    Returns:
        Total number of parameters with ``requires_grad`` set.
    """
    def _trainable(module):
        # Only parameters that will receive gradients are counted.
        count = 0
        for param in module.parameters():
            if param.requires_grad:
                count += param.numel()
        return count

    def _short(num):
        # Abbreviate with a B / M / K suffix once the threshold is reached.
        if num >= 1e9:
            return f"{num/1e9:.2f}B"
        if num >= 1e6:
            return f"{num/1e6:.2f}M"
        if num >= 1e3:
            return f"{num/1e3:.2f}K"
        return str(num)

    def _row(label, count):
        print(f"{label:.<35} {count:>15,} ({_short(count):>8})")

    model = getattr(model, 'module', model)
    total_params = _trainable(model)

    if not verbose:
        return total_params

    rule = "=" * 80
    print("\n" + rule)
    print("Model Parameter Statistics")
    print(rule)

    encoder_params = 0
    for name in ('patch_embed', 'blocks', 'encoder_norm'):
        if not hasattr(model, name):
            continue
        part = _trainable(getattr(model, name))
        encoder_params += part
        _row(name, part)

    if hasattr(model, 'head'):
        _row('Classification/Regression Head', _trainable(model.head))

    print("\n" + rule)
    _row('Encoder Parameters', encoder_params)
    _row('TOTAL TRAINABLE PARAMETERS', total_params)
    print(rule + "\n")

    return total_params
|
|
|
|
|
|
def save_checkpoint(state, is_best, checkpoint_dir, filename='checkpoint.pth'):
    """Save a checkpoint, plus a copy as ``checkpoint_best.pth`` if best.

    Args:
        state: Object to serialise with ``torch.save``.
        is_best: When true, also write ``checkpoint_best.pth`` in the same
            directory.
        checkpoint_dir: Directory for the files; created if it is missing.
        filename: File name of the regular checkpoint.
    """
    # Create the target directory on first use instead of failing inside
    # torch.save. An empty string means the current directory.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        best_path = os.path.join(checkpoint_dir, 'checkpoint_best.pth')
        torch.save(state, best_path)
|
|
|
|
def load_checkpoint(checkpoint_path, model, optimizer, scheduler, scaler=None):
    """Restore training state from a checkpoint file.

    Args:
        checkpoint_path: Path of the checkpoint file.
        model: Receives ``checkpoint['model_state_dict']``.
        optimizer: Receives ``checkpoint['optimizer_state_dict']``; skipped
            when ``None``.
        scheduler: Receives ``checkpoint['scheduler_state_dict']``; skipped
            when ``None``.
        scaler: Optional object restored from ``'scaler_state_dict'`` when
            that key is present in the checkpoint.

    Returns:
        Tuple ``(start_epoch, best_metric, best_loss)``. When no file exists
        at ``checkpoint_path`` this is ``(0, 0.0, float('inf'))``; stored
        values missing from the checkpoint default to ``0.0`` and
        ``float('inf')`` respectively.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at '{checkpoint_path}'")
        # Same defaults as the .get() fallbacks below. A best_loss of 0.0
        # here would stop any later loss from ever counting as an improvement.
        return 0, 0.0, float('inf')

    print(f"Loading checkpoint '{checkpoint_path}'")
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    start_epoch = checkpoint['epoch']
    best_metric = checkpoint.get('best_metric', 0.0)
    best_loss = checkpoint.get('best_loss', float('inf'))

    model.load_state_dict(checkpoint['model_state_dict'])
    # Allow evaluation-only callers to pass None, as already done for scaler.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    print(f"Loaded checkpoint from epoch {start_epoch}")
    return start_epoch, best_metric, best_loss
|
|
|
|
|
|
class LabelScaler:
    """Standardise labels with a fixed mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, labels):
        """Standardise: (y - mean) / std."""
        centered = labels - self.mean
        return centered / self.std