import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from diffusers import UNet2DConditionModel, AutoencoderKL
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
import wandb
import random
import gc
from accelerate.state import DistributedType
from torch.distributed import broadcast_object_list
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from datetime import datetime
import bitsandbytes as bnb
import torch.nn.functional as F
from collections import deque
from transformers import AutoTokenizer, AutoModel

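# ---------------------------------------------------------------------------
# Configuration: paths, hyperparameters and feature switches
# ---------------------------------------------------------------------------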
ds_path = "/workspace/sdxs/datasets/768"
project = "sdxs_08b"
batch_size = 64
base_learning_rate = 4e-5
min_learning_rate = 9e-6
num_epochs = 1
sample_interval_share = 10
cfg_dropout = 0.75
max_length = 192
use_wandb = False
use_comet_ml = False
save_model = False
use_decay = True
fbp = True  # fused backward pass: one 8-bit optimizer per parameter, stepped inside a grad hook
optimizer_type = "adam8bit"
torch_compile = False
unet_gradient = True
fixed_seed = False
shuffle = True
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(False)
dtype = torch.float32
save_barrier = 1.01
warmup_percent = 0.01
percentile_clipping = 95
beta2 = 0.995
eps = 1e-7
clip_grad_norm = 1.0
limit = 0
checkpoints_folder = ""
mixed_precision = "no"
gradient_accumulation_steps = 1

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

# Sampling settings used for periodic evaluation
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4

generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

# Seed derived from the current date; only applied when fixed_seed is enabled
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d"))
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# LoRA settings (full fine-tuning is used while lora_name is empty)
lora_name = ""
lora_rank = 32
lora_alpha = 64

print("init")

# Desired share of each loss term in the total loss
loss_ratios = {
    "mse": 0.9,
    "mae": 0.1,
}
median_coeff_steps = 256

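# MedianLossNormalizer keeps a sliding window (median_coeff_steps values) of each loss
# term's absolute magnitude and rescales every term by ratio / median, so that at the
# current scale of the losses each term contributes roughly its desired share of the
# total (here 90% MSE / 10% MAE).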
class MedianLossNormalizer:
    def __init__(self, desired_ratios: dict, window_steps: int):
        self.ratios = {k: float(v) for k, v in desired_ratios.items()}
        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
        self.window = window_steps

    def update_and_total(self, losses: dict):
        """
        losses: dict of name -> tensor (loss values)
        Behaviour:
        - buffer ABS(l) only for active (ratio > 0) losses
        - coeff = ratio / median(abs(loss))
        - total = sum(coeff * loss) over the active losses
        CHANGED: abs() is buffered so the median stays positive and does not break the division.
        """
        for k, v in losses.items():
            if k in self.buffers and self.ratios.get(k, 0) > 0:
                val = v.detach().abs().mean().cpu().item()
                self.buffers[k].append(val)

        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}

        total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
        return total, coeffs, meds


normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
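# Example with hypothetical numbers: with ratios {"mse": 0.9, "mae": 0.1} and running
# medians of ~0.3 (mse) and ~0.4 (mae), the coefficients become 0.9 / 0.3 = 3.0 and
# 0.1 / 0.4 = 0.25, so total = 3.0 * mse + 0.25 * mae on that step.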

if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project + lora_name, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

torch.backends.cuda.enable_flash_sdp(True)

# Frozen VAE (kept on CPU until sampling) and frozen text encoder
vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()

tokenizer = AutoTokenizer.from_pretrained("tokenizer")
text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()

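# Text conditioning: each caption is wrapped in the tokenizer's chat template, encoded
# with the frozen text model, and the penultimate hidden layer is used as the per-token
# embeddings. A pooled vector (hidden state of the last non-padding token) is prepended
# as an extra token, so the UNet sees max_length + 1 conditioning tokens.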
def encode_texts(texts, max_length=max_length):
    if texts is None:
        raise ValueError("encode_texts() expects a string or a list of strings, got None")

    with torch.no_grad():
        if isinstance(texts, str):
            texts = [texts]

        # Wrap every prompt in the tokenizer's chat template
        for i, prompt_item in enumerate(texts):
            messages = [
                {"role": "user", "content": prompt_item},
            ]
            texts[i] = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

        toks = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length
        ).to(device)

        outs = text_model(**toks, output_hidden_states=True, return_dict=True)

        # Penultimate hidden layer as per-token embeddings
        hidden = outs.hidden_states[-2]
        attention_mask = toks["attention_mask"]

        # Pooled embedding: hidden state of the last non-padding token
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = hidden.shape[0]
        pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
        pooled_expanded = pooled.unsqueeze(1)

        # Prepend the pooled token, making the sequence length max_length + 1
        new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
        new_attention_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype), attention_mask],
            dim=1
        )

        return new_encoder_hidden_states, new_attention_mask

shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0

from diffusers import FlowMatchEulerDiscreteScheduler
num_train_timesteps = 1000
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)

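# DistributedResolutionBatchSampler groups dataset indices by (width, height) and builds
# batches that contain a single resolution only, so precomputed latents can be stacked
# without resizing. Each global batch is split contiguously across ranks, and only full
# batches per resolution group are kept.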
class DistributedResolutionBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        rng = np.random.RandomState(self.epoch)

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)
            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue
            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]
            all_batches.extend(gpu_batches)

        if self.shuffle:
            rng.shuffle(all_batches)
        accelerator.wait_for_everyone()
        return iter(all_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch

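# Fixed evaluation samples: a few dataset items per resolution group are frozen at startup
# (latents, captions and their embeddings) so periodic sample generation always uses the
# same prompts and progress stays visually comparable across training steps.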
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size = (w, h)
        size_groups[size].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            n_samples = samples_to_generate
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        texts = [item["text"] for item in samples_data]

        embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Created {len(fixed_samples)} fixed-sample groups by resolution")
    return fixed_samples

if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

dataset = dataset.filter(
    lambda x: [not (path.startswith("/workspace/ds/animesfw") or path.startswith("/workspace/ds/d4/animesfw")) for path in x["image_path"]],
    batched=True,
    batch_size=10000,
    num_proc=8
)
print(f"Examples remaining after filtering: {len(dataset)}")

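# collate_fn_simple builds a batch directly on the GPU: it stacks the precomputed VAE
# latents and encodes the captions. Captions are replaced with an empty string with
# probability cfg_dropout (and always when they start with "zero"), which supplies the
# unconditional examples needed for classifier-free guidance; common boilerplate
# prefixes are stripped from the remaining captions.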
def collate_fn_simple(batch):
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)

    raw_texts = [item["text"] for item in batch]
    texts = [
        "" if t.lower().startswith("zero")
        else "" if random.random() < cfg_dropout
        else t[1:].lstrip() if t.startswith(".")
        else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ", "").strip()
        for t in raw_texts
    ]

    embeddings, attention_mask = encode_texts(texts)

    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask

batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    print("Total batches per epoch:", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes

# Load the UNet from the latest checkpoint; training does not start from scratch here
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        unet.enable_gradient_checkpointing()
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Failed to enable SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

if lora_name:
    # LoRA setup would go here; full fine-tuning is used while lora_name is empty
    pass

if lora_name:
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
else:
    if fbp:
        trainable_params = list(unet.parameters())

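# Optimizer factory. "adam8bit" uses bitsandbytes AdamW8bit with percentile gradient
# clipping; "adam" is plain torch AdamW; "muon" builds a MuonWithAuxAdam optimizer with
# separate parameter groups for >=2D weights (Muon) and gains/biases (Adam), then wraps
# it in SnooC before returning.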
def create_optimizer(name, params):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, beta2), eps=eps, weight_decay=0.01,
            percentile_clipping=percentile_clipping
        )
    elif name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, beta2), eps=1e-8, weight_decay=0.01
        )
    elif name == "muon":
        from muon import MuonWithAuxAdam
        trainable_params = [p for p in params if p.requires_grad]
        hidden_weights = [p for p in trainable_params if p.ndim >= 2]
        hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]

        param_groups = [
            dict(params=hidden_weights, use_muon=True,
                 lr=1e-3, weight_decay=1e-4),
            dict(params=hidden_gains_biases, use_muon=False,
                 lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
        ]
        optimizer = MuonWithAuxAdam(param_groups)
        from snooc import SnooC
        return SnooC(optimizer)
    else:
        raise ValueError(f"Unknown optimizer: {name}")

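# Fused backward pass (fbp): instead of one global optimizer, every trainable parameter
# gets its own optimizer instance, and a post-accumulate-grad hook steps it and frees the
# gradient as soon as that parameter's grad is ready. This keeps peak memory low at the
# cost of a fixed learning rate (no LR scheduler is attached in this mode).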
if fbp:
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    optimizer = create_optimizer(optimizer_type, unet.parameters())

    def lr_schedule(step):
        x = step / (total_training_steps * world_size)
        warmup = warmup_percent
        if not use_decay:
            return base_learning_rate
        if x < warmup:
            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
        decay_ratio = (x - warmup) / (1 - warmup)
        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
            (1 + math.cos(math.pi * decay_ratio))

    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

if torch_compile:
    print("compiling")
    unet = torch.compile(unet)
    print("compiling - ok")

fixed_samples = get_fixed_samples_by_resolution(dataset)

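# Negative (unconditional) prompt embedding reused by classifier-free guidance during
# sampling; the script uses "low quality" as the negative prompt. An empty negative
# prompt falls back to zero embeddings with a full mask, sized max_length + 1 to match
# the pooled token prepended by encode_texts.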
def get_negative_embedding(neg_prompt="", batch_size=1):
    if not neg_prompt:
        hidden_dim = 1024  # must match the text encoder's hidden size
        seq_len = max_length + 1  # +1 for the pooled token prepended by encode_texts
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    uncond_emb, uncond_mask = encode_texts([neg_prompt])
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)

    return uncond_emb, uncond_mask


uncond_emb, uncond_mask = get_negative_embedding("low quality")

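# Sample generation: runs the flow-matching Euler sampler for n_diffusion_steps. When
# guidance_scale != 1 the unconditional and conditional branches are batched together
# as [neg, cond] and combined as
#   flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond),
# then the latents are decoded with the VAE and the images are saved / logged.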
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    # Duplicate the latents so unconditional and conditional passes run in one batch
                    latent_model_input = torch.cat([latents, latents], dim=0)

                    curr_batch_size = sample_text_embeddings.shape[0]
                    seq_len = sample_text_embeddings.shape[1]
                    hidden_dim = sample_text_embeddings.shape[2]

                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)

                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents
            if step == 0:
                # On the very first call decode the reference latents instead (VAE sanity check)
                current_latents = sample_latents

            latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
            decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                if np.isnan(img).any():
                    print("NaNs found in decoded image! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        vae.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()

# Baseline sample generation before training starts
if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()

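# Checkpoint saving: full-model checkpoints go to checkpoints_folder/<project>; with a
# non-empty `variant` the weights are cast to fp16 before save_pretrained and the
# training dtype is restored afterwards. The LoRA branch assumes a save_lora_checkpoint
# helper defined elsewhere.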
def save_checkpoint(unet, variant=""):
    if accelerator.is_main_process:
        if lora_name:
            save_lora_checkpoint(unet)
        else:
            model_to_save = None
            if not torch_compile:
                model_to_save = accelerator.unwrap_model(unet)
            else:
                model_to_save = unet

            if variant != "":
                model_to_save.to(dtype=torch.float16).save_pretrained(
                    os.path.join(checkpoints_folder, f"{project}"), variant=variant
                )
            else:
                model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))

    # Restore the training dtype in case the model was cast to fp16 above
    unet = unet.to(dtype=dtype)

if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
sample_interval = max(1, steps_per_epoch // sample_interval_share)
min_loss = 2.0

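# Flow-matching objective used below: for t ~ U(0, 1) the noisy latent is
#   x_t = (1 - t) * x_0 + t * eps,   eps ~ N(0, I),
# and the UNet is trained to predict the constant velocity
#   v = d x_t / d t = eps - x_0,
# which is exactly the `target = noise - latents` computed inside the loop.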
for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()

    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            if not save_model and epoch == 0 and step == 5:
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Step {step}: {used_gb:.2f} GB")

            noise = torch.randn_like(latents, dtype=latents.dtype)

            t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)

            # Linear interpolation between data and noise (flow-matching forward process)
            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise

            timesteps = (t * scheduler.config.num_train_timesteps).long()

            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=embeddings,
                encoder_attention_mask=attention_mask
            ).sample

            # Velocity target of the linear flow
            target = noise - latents

            mse_loss = F.mse_loss(model_pred.float(), target.float())
            mae_loss = F.l1_loss(model_pred.float(), target.float())
            batch_losses.append(mse_loss.detach().item())

            # Periodic synchronisation across ranks
            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            losses_dict = {}
            losses_dict["mse"] = mse_loss
            losses_dict["mae"] = mae_loss

            abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
            total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            accelerator.backward(total_loss)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                    grad = float(grad_val)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
            if accelerator.is_main_process:
                if fbp:
                    current_lr = base_learning_rate
                else:
                    current_lr = lr_scheduler.get_last_lr()[0]
                batch_grads.append(grad)

                log_data = {}
                log_data["loss_mse"] = mse_loss.detach().item()
                log_data["loss_mae"] = mae_loss.detach().item()
                log_data["lr"] = current_lr
                if not fbp:
                    log_data["grad"] = grad
                log_data["loss_norm"] = float(total_loss.item())
                for k, c in coeffs.items():
                    log_data[f"coeff_{k}"] = float(c)
                if accelerator.sync_gradients:
                    if use_wandb:
                        wandb.log(log_data, step=global_step)
                    if use_comet_ml:
                        comet_experiment.log_metrics(log_data, step=global_step)

                if global_step % sample_interval == 0:
                    if save_model:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    elif epoch % 10 == 0:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    last_n = sample_interval

                    if save_model:
                        has_losses = len(batch_losses) > 0
                        avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
                        last_loss = batch_losses[-1] if has_losses else 0.0
                        max_loss = max(avg_sample_loss, last_loss)
                        should_save = max_loss < min_loss * save_barrier
                        print(
                            f"Saving: {should_save} | Max: {max_loss:.4f} | "
                            f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                        )

                        if should_save:
                            min_loss = max_loss
                            save_checkpoint(unet)

    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0

        print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)

if accelerator.is_main_process:
    print("Training finished! Saving the final model...")

    save_checkpoint(unet, "fp16")
    if use_comet_ml:
        comet_experiment.end()

accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

print("Done!")