| |
| |
|
|
| import argparse |
| import copy |
| import json |
| import logging |
| import math |
| import os |
| import random |
| import shutil |
| from contextlib import nullcontext |
| from pathlib import Path |
|
|
| import datasets |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| import transformers |
| from accelerate import Accelerator |
| from accelerate.logging import get_logger |
| from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed |
| from datasets import load_dataset |
| from packaging import version |
| from torchvision import transforms |
| from torchvision.transforms.functional import crop |
| from tqdm.auto import tqdm |
| from transformers import CLIPTokenizer, T5TokenizerFast |
| from PIL import Image |
| from transformers.configuration_utils import PretrainedConfig |
|
|
| import diffusers |
| from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL |
| from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler |
| from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
| from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline |
| from diffusers.optimization import get_scheduler |
| from diffusers.training_utils import ( |
| cast_training_params, |
| compute_density_for_timestep_sampling, |
| compute_loss_weighting_for_sd3, |
| ) |
| from diffusers.utils import check_min_version |
| from diffusers.utils.import_utils import is_wandb_available |
| from peft import LoraConfig |
| from peft.utils import set_peft_model_state_dict |
|
|
| |
# Fail fast if the installed diffusers is older than the API surface used below.
check_min_version("0.30.0")


# Module-level logger that cooperates with Accelerate's multi-process handling.
logger = get_logger(__name__)
|
|
|
|
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention used for feature processing inside SIT blocks.

    Projects queries/keys/values, applies scaled dot-product attention per
    head, then projects the concatenated heads back to ``hidden_size``.
    """

    def __init__(self, hidden_size, num_attention_heads):
        """
        Initialize the multi-head attention module.

        Args:
            hidden_size: feature dimension of inputs and outputs.
            num_attention_heads: number of attention heads.
        """
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # NOTE: if hidden_size is not divisible by num_attention_heads, the
        # floor division silently shrinks all_head_size below hidden_size.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Q/K/V projections plus attention-probability dropout.
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(0.1)

        # Maps concatenated heads back to the model width.
        self.output_projection = nn.Linear(self.all_head_size, hidden_size)

    def transpose_for_scores(self, x):
        """Reshape (B, S, all_head_size) -> (B, heads, S, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, query_states, key_states, value_states):
        """
        Compute multi-head attention.

        Args:
            query_states: query tensor (B, S, hidden_size).
            key_states: key tensor (B, S, hidden_size).
            value_states: value tensor (B, S, hidden_size).

        Returns:
            Attention output tensor of shape (B, S, hidden_size).
        """
        query_layer = self.transpose_for_scores(self.query(query_states))
        key_layer = self.transpose_for_scores(self.key(key_states))
        value_layer = self.transpose_for_scores(self.value(value_states))

        # Scaled dot-product attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Fix: use the functional softmax instead of constructing a fresh
        # nn.Softmax module on every forward pass (same math, no allocation).
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, then merge heads back together.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output = self.output_projection(context_layer)

        return output
|
|
|
|
class SITBlock(nn.Module):
    """
    Scalable Interpolant Transformer block used for rectified-noise training.

    A standard pre-norm Transformer block: LayerNorm -> multi-head
    self-attention -> residual add, followed by LayerNorm -> GELU MLP ->
    residual add.
    """

    def __init__(self, hidden_size, num_attention_heads, intermediate_size):
        """
        Build one SIT block.

        Args:
            hidden_size: width of the token features.
            num_attention_heads: number of attention heads.
            intermediate_size: hidden width of the feed-forward MLP.
        """
        super().__init__()
        self.attention = MultiHeadAttention(hidden_size, num_attention_heads)
        self.layernorm1 = nn.LayerNorm(hidden_size)
        self.layernorm2 = nn.LayerNorm(hidden_size)

        # Two-layer GELU MLP mapping hidden -> intermediate -> hidden.
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, hidden_states, timestep=None):
        """
        Apply the block.

        Args:
            hidden_states: input token features (B, S, hidden_size).
            timestep: reserved for future conditioning; currently unused.

        Returns:
            Transformed features with the same shape as the input.
        """
        # Attention sub-layer with pre-norm and residual connection.
        normed = self.layernorm1(hidden_states)
        hidden_states = hidden_states + self.attention(normed, normed, normed)

        # Feed-forward sub-layer with pre-norm and residual connection.
        normed = self.layernorm2(hidden_states)
        return hidden_states + self.feed_forward(normed)
|
|
|
|
class RectifiedNoiseModule(nn.Module):
    """
    Rectified-noise module built from SIT blocks with a dual-output head.

    Consumes SD3 intermediate features and produces a rectified noise
    prediction via a dual (mean + variance) output mechanism.
    """

    def __init__(self, hidden_size, num_sit_layers=1, num_attention_heads=16, input_dim=None, transformer_hidden_size=None):
        """
        Initialize the rectified-noise module.

        Args:
            hidden_size: internal width of the SIT blocks.
            num_sit_layers: number of stacked SIT blocks.
            num_attention_heads: attention heads per SIT block.
            input_dim: final output dimension (defaults to 16).
            transformer_hidden_size: hidden size of the SD3 transformer
                features (defaults to 1536).
        """
        super().__init__()
        self.num_sit_layers = num_sit_layers
        self.hidden_size = hidden_size
        self.gradient_checkpointing = False

        self.input_dim = input_dim or 16
        self.transformer_hidden_size = transformer_hidden_size or 1536

        # Project transformer features into the SIT width when they differ.
        self.input_projection = None
        if self.transformer_hidden_size != self.hidden_size:
            self.input_projection = nn.Linear(self.transformer_hidden_size, self.hidden_size)

        # Fallback head count when the requested count does not divide evenly.
        # NOTE(review): assumes hidden_size is divisible by 32 — confirm for
        # non-standard hidden sizes.
        if self.hidden_size % num_attention_heads != 0:
            num_attention_heads = 32

        self.sit_blocks = nn.ModuleList([
            SITBlock(self.hidden_size, num_attention_heads, self.hidden_size * 4)
            for _ in range(num_sit_layers)
        ])

        # Emits concatenated (mean, variance) channels in one linear layer.
        self.dual_output_layer = nn.Linear(self.hidden_size, self.hidden_size * 2)

        # Maps the SIT width back to the transformer hidden size.
        self.output_projection = nn.Linear(self.hidden_size, self.transformer_hidden_size)

    def gradient_checkpointing_enable(self):
        """Set the checkpointing flag (flag only; forward does not use it yet)."""
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Clear the checkpointing flag."""
        self.gradient_checkpointing = False

    def forward(self, intermediate_features):
        """
        Forward pass of the rectified-noise module.

        Args:
            intermediate_features: SD3 features of shape
                (batch, seq_len, transformer_hidden_size).

        Returns:
            tuple: (rectified_output, mean_output, var_output), each of shape
            (batch, seq_len, transformer_hidden_size). rectified_output and
            mean_output carry the same projected mean (as in the original
            implementation, where both were computed from the same chunk).
        """
        if self.input_projection is not None:
            hidden_states = self.input_projection(intermediate_features)
        else:
            hidden_states = intermediate_features

        # Run the stacked SIT blocks.
        for block in self.sit_blocks:
            hidden_states = block(hidden_states)

        # Split the dual head into mean and variance channels.
        dual_output = self.dual_output_layer(hidden_states)
        mean_chunk, var_chunk = torch.chunk(dual_output, 2, dim=-1)

        # Fix: the original projected mean_chunk twice, producing two identical
        # tensors at twice the cost (and left three no-op self-assignments).
        # Project once and reuse the result for both outputs.
        mean_output = self.output_projection(mean_chunk)
        var_output = self.output_projection(var_chunk)

        return mean_output, mean_output, var_output
|
|
|
|
class SD3WithRectifiedNoise(nn.Module):
    """
    SD3 model with integrated Rectified Noise capability.

    Wraps the original SD3 Transformer together with the rectified-noise
    module; the SD3 parameters are frozen and only the SIT-based module is
    trainable. Intermediate features are captured via a forward hook on
    ``transformer.norm_out``.
    """

    def __init__(self, transformer, rectified_noise_module):
        """
        Initialize the combined model.

        Args:
            transformer: the original SD3 Transformer model.
            rectified_noise_module: the rectified-noise (SIT) module.
        """
        super().__init__()
        self.transformer = transformer
        self.rectified_noise_module = rectified_noise_module
        # Populated as a side effect of the norm_out forward hook.
        self.intermediate_features = None

        # Freeze the SD3 backbone; only rectified_noise_module trains.
        for param in self.transformer.parameters():
            param.requires_grad = False

        self._register_hooks()

    def _register_hooks(self):
        """Register a hook capturing the features produced after norm_out."""
        def norm_out_hook(module, input, output):
            self.intermediate_features = output.clone()

        if hasattr(self.transformer, 'norm_out'):
            self.transformer.norm_out.register_forward_hook(norm_out_hook)

    def forward(self, hidden_states, timestep, encoder_hidden_states, pooled_projections, return_dict=False, skip_layers=None):
        """
        Forward pass of the combined model.

        Args:
            hidden_states: input latents of shape (batch, channels, H, W).
            timestep: diffusion timestep(s).
            encoder_hidden_states: text-encoder hidden states.
            pooled_projections: pooled text projections.
            return_dict: whether to return results as a dict.
            skip_layers: kept for API compatibility; unused here.

        Returns:
            Tuple ``(sample, mean_output, var_output)``, or the same values
            in a dict when ``return_dict`` is True.
        """
        batch_size, channels, height, width = hidden_states.shape
        self.intermediate_features = None

        # Reconcile batch sizes: with classifier-free guidance during training
        # the latents may be duplicated (2x) while text embeddings are not.
        if hidden_states.shape[0] != encoder_hidden_states.shape[0]:
            # In training mode, assume the latent batch is CFG-doubled.
            actual_batch_size = hidden_states.shape[0] // 2 if self.training else hidden_states.shape[0]
            if encoder_hidden_states.shape[0] == actual_batch_size:
                # Duplicate the text conditioning to match the doubled latents.
                encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0)
                pooled_projections = torch.cat([pooled_projections, pooled_projections], dim=0)
            elif encoder_hidden_states.shape[0] == hidden_states.shape[0]:
                # Already aligned; nothing to do.
                pass
            else:
                # Irreconcilable mismatch — warn and proceed (may fail later).
                print(f"警告:batch size不匹配,hidden_states: {hidden_states.shape}, encoder_hidden_states: {encoder_hidden_states.shape}")

        final_output, mean_out, var_out = self._forward_with_sit_correction(
            hidden_states, timestep, encoder_hidden_states, pooled_projections
        )

        if return_dict:
            return {
                "sample": final_output,
                "mean_output": mean_out,
                "var_output": var_out,
            }
        else:
            return (final_output, mean_out, var_out)

    def _forward_with_sit_correction(self, hidden_states, timestep, encoder_hidden_states, pooled_projections):
        """Run the SD3 forward pass and apply the SIT correction after norm_out."""
        batch_size, channels, height, width = hidden_states.shape

        # Full frozen-transformer forward; the norm_out hook fills
        # self.intermediate_features as a side effect.
        original_output = self.transformer(
            hidden_states=hidden_states,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            return_dict=False
        )[0]

        # Fallback path: the hook did not fire, so manually replay the
        # transformer internals to recover the post-norm_out features.
        if self.intermediate_features is None:

            # No-op reassignment kept from the original code.
            self.intermediate_features = None

            # Patch-embed the latents if the transformer exposes pos_embed.
            temp_hidden_states = hidden_states
            if hasattr(self.transformer, 'pos_embed'):
                temp_hidden_states = self.transformer.pos_embed(temp_hidden_states)

            if hasattr(self.transformer, 'time_text_embed'):
                temb = self.transformer.time_text_embed(timestep, pooled_projections)
            else:
                temb = None

            if hasattr(self.transformer, 'context_embedder'):
                temp_encoder_hidden_states = self.transformer.context_embedder(encoder_hidden_states)
            else:
                temp_encoder_hidden_states = encoder_hidden_states

            # Replay all joint transformer blocks.
            if hasattr(self.transformer, 'transformer_blocks'):
                for i, block in enumerate(self.transformer.transformer_blocks):
                    if temb is not None:
                        temp_encoder_hidden_states, temp_hidden_states = block(
                            temp_hidden_states, temp_encoder_hidden_states, temb
                        )
                    else:
                        temp_encoder_hidden_states, temp_hidden_states = block(
                            temp_hidden_states, temp_encoder_hidden_states
                        )

            # NOTE(review): if the transformer has no norm_out attribute,
            # intermediate_features remains None here and the module call
            # below will fail — confirm all supported transformers define it.
            if hasattr(self.transformer, 'norm_out'):
                if temb is not None:
                    temp_hidden_states = self.transformer.norm_out(temp_hidden_states, temb)
                else:
                    temp_hidden_states = self.transformer.norm_out(temp_hidden_states)

                self.intermediate_features = temp_hidden_states
        else:
            # Sanitize any NaN/inf values captured by the hook.
            if torch.isnan(self.intermediate_features).any():
                self.intermediate_features = torch.nan_to_num(self.intermediate_features, nan=0.0, posinf=1e6, neginf=-1e6)

        rectified_s_seq, mean_out_seq, var_out_seq = self.rectified_noise_module(self.intermediate_features)

        # The per-token SIT outputs are collapsed to global scalar means and
        # broadcast over the original prediction.
        final_output = original_output + rectified_s_seq.mean()
        mean_out = torch.zeros_like(original_output) + mean_out_seq.mean()
        var_out = torch.zeros_like(original_output) + var_out_seq.mean()

        return final_output, mean_out, var_out
|
|
|
|
def log_validation(
    pipeline,
    args,
    accelerator,
    epoch,
    is_final_validation=False,
    global_step=None,
):
    """Run validation, save the generated images, and log them to trackers.

    Args:
        pipeline: StableDiffusion3Pipeline used to generate images.
        args: training arguments.
        accelerator: Accelerator instance.
        epoch: current epoch.
        is_final_validation: whether this is the final validation run.
        global_step: global training step (used in saved filenames when set).

    Returns:
        The list of generated images.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Seed the generator for reproducible validation when a seed is configured.
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
    pipeline_args = {"prompt": args.validation_prompt}

    # MPS does not support autocast here; fall back to a no-op context.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type)

    with autocast_ctx:
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    # Persist images to disk from the main process only.
    if accelerator.is_main_process:
        validation_dir = os.path.join(args.output_dir, "validation_images")
        os.makedirs(validation_dir, exist_ok=True)
        for i, image in enumerate(images):
            # Include the global step in the filename when available.
            if global_step is not None:
                filename = f"validation_step_{global_step}_epoch_{epoch}_img_{i}.png"
            else:
                filename = f"validation_epoch_{epoch}_img_{i}.png"

            image_path = os.path.join(validation_dir, filename)
            image.save(image_path)
            logger.info(f"Saved validation image: {image_path}")

    # Forward the images to any configured experiment trackers.
    for tracker in accelerator.trackers if hasattr(accelerator, 'trackers') and accelerator.trackers else []:
        phase_name = "test" if is_final_validation else "validation"
        try:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                import wandb
                tracker.log(
                    {
                        phase_name: [
                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                        ]
                    }
                )
        except Exception as e:
            logger.warning(f"Failed to log to {tracker.name}: {e}")

    # Release pipeline memory before resuming training.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return images
|
|
|
|
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"):
    """
    Resolve the text-encoder class declared in a pretrained model's config.

    Args:
        pretrained_model_name_or_path: pretrained model name or local path.
        revision: model revision to load the config from.
        subfolder: config subfolder name.

    Returns:
        The matching transformers text-encoder class.

    Raises:
        ValueError: if the declared architecture is neither CLIP nor T5.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = config.architectures[0]

    if model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection

    if model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel

    raise ValueError(f"{model_class} is not supported.")
|
|
|
|
# Maps known dataset names to their (image_column, caption_column) pair.
DATASET_NAME_MAPPING = {
    "lambdalabs/naruto-blip-captions": ("image", "text"),
}
|
|
|
|
def load_dataset_from_jsonl(metadata_path, data_dir, accelerator=None):
    """
    Load a dataset from a metadata.jsonl file, avoiding a full directory scan.
    This matters for large datasets in distributed training.

    NOTE(review): despite the original docstring's claim, every process reads
    the jsonl file here — only the *logging* is restricted to the main
    process. There is no cross-process synchronization in this function.

    Args:
        metadata_path: path to the metadata.jsonl file.
        data_dir: dataset root directory (image paths are joined onto it).
        accelerator: Accelerator object, used only to gate logging.

    Returns:
        datasets.DatasetDict with a single 'train' split.

    Raises:
        FileNotFoundError: when metadata_path does not exist.
    """
    if accelerator is None or accelerator.is_main_process:
        logger.info(f"Loading dataset from metadata.jsonl: {metadata_path}")

    # Accumulate {'image': path, 'text': caption} records line by line.
    data_list = []
    if os.path.exists(metadata_path):
        with open(metadata_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f):
                try:
                    item = json.loads(line.strip())
                    file_name = item.get('file_name', '')
                    caption = item.get('caption', '')

                    # Resolve the image path relative to the dataset root.
                    # Existence of the image file is NOT verified here.
                    image_path = os.path.join(data_dir, file_name)

                    data_list.append({
                        'image': image_path,
                        'text': caption
                    })

                    # Periodic progress logging for very large files.
                    if (line_num + 1) % 100000 == 0 and (accelerator is None or accelerator.is_main_process):
                        logger.info(f"Processed {line_num + 1} entries from metadata.jsonl")

                except json.JSONDecodeError as e:
                    # Skip malformed lines rather than aborting the whole load.
                    if accelerator is None or accelerator.is_main_process:
                        logger.warning(f"Skipping invalid JSON at line {line_num + 1}: {e}")
                    continue

        if accelerator is None or accelerator.is_main_process:
            logger.info(f"Loaded {len(data_list)} image-caption pairs from metadata.jsonl")
    else:
        raise FileNotFoundError(f"metadata.jsonl not found at: {metadata_path}")

    # Wrap the records in a datasets.Dataset under a 'train' split.
    dataset = datasets.Dataset.from_list(data_list)

    return datasets.DatasetDict({'train': dataset})
|
|
|
|
def parse_args(input_args=None):
    """
    Parse command-line arguments for the SD3 rectified-noise training script.

    Args:
        input_args: optional list of argument strings; defaults to sys.argv.

    Returns:
        The parsed argument namespace.

    Raises:
        ValueError: when neither a dataset name nor a training folder is
            given, or when the LoRA model path does not exist.
    """
    parser = argparse.ArgumentParser(description="SD3 Rectified Noise 精简训练脚本")

    # Model paths and revisions.
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--lora_model_path", type=str, required=True)
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--variant", type=str, default=None)

    # Rectified-noise (SIT) hyperparameters.
    parser.add_argument("--num_sit_layers", type=int, default=1)
    parser.add_argument("--sit_learning_rate", type=float, default=1e-4)
    parser.add_argument("--kl_loss_weight", type=float, default=0.1)
    parser.add_argument("--time_weight_alpha", type=float, default=2.0, help="时间权重衰减参数,越大衰减越快")
    parser.add_argument("--save_sit_weights_only", action="store_true")

    # Dataset selection.
    parser.add_argument("--dataset_name", type=str, default=None, help="The name of the Dataset to train on.")
    parser.add_argument("--dataset_config_name", type=str, default=None, help="The config of the Dataset.")
    parser.add_argument("--train_data_dir", type=str, default=None)
    parser.add_argument("--image_column", type=str, default="image")
    parser.add_argument("--caption_column", type=str, default="caption")

    # Training, validation, and checkpointing options.
    parser.add_argument("--max_sequence_length", type=int, default=77)
    parser.add_argument("--validation_prompt", type=str, default=None, help="验证期间使用的提示词")
    parser.add_argument("--num_validation_images", type=int, default=4, help="验证生成的图像数量")
    parser.add_argument("--validation_steps", type=int, default=1, help="每 X 个 steps 运行一次验证 (设置为0禁用基于steps的验证)")
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="rectified-noise-model")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--resolution", type=int, default=1024)
    parser.add_argument("--center_crop", default=False, action="store_true")
    parser.add_argument("--random_flip", action="store_true")
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--checkpointing_steps", type=int, default=500)
    parser.add_argument("--checkpoints_total_limit", type=int, default=3)
    parser.add_argument("--resume_from_checkpoint", type=str, default="latest")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--scale_lr", action="store_true", default=False)
    parser.add_argument("--lr_scheduler", type=str, default="constant")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)

    # Flow-matching timestep weighting options.
    parser.add_argument("--weighting_scheme", type=str, default="logit_normal",
                        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"])
    parser.add_argument("--logit_mean", type=float, default=0.0)
    parser.add_argument("--logit_std", type=float, default=1.0)
    parser.add_argument("--mode_scale", type=float, default=1.29)
    parser.add_argument("--precondition_outputs", type=int, default=1)

    # Optimizer and hardware options.
    parser.add_argument("--allow_tf32", action="store_true")
    parser.add_argument("--dataloader_num_workers", type=int, default=0)
    parser.add_argument("--use_8bit_adam", action="store_true")
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    parser.add_argument("--max_grad_norm", default=1.0, type=float)

    # Logging and precision options.
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument("--report_to", type=str, default="tensorboard")
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
    parser.add_argument("--local_rank", type=int, default=-1)

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Let a launcher-provided LOCAL_RANK env var override the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks on the supplied data and LoRA paths.
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    if not os.path.exists(args.lora_model_path):
        raise ValueError(f"LoRA model path does not exist: {args.lora_model_path}")

    return args
|
|
|
|
def check_lora_weights_exist(lora_path):
    """Return True when *lora_path* points at LoRA weights in standard layout.

    A directory qualifies when it contains ``pytorch_lora_weights.safetensors``
    or any ``*.safetensors`` file whose name mentions "lora"; a plain file
    qualifies when it is itself a ``.safetensors`` file.
    """
    if not lora_path:
        return False

    if os.path.isfile(lora_path):
        return lora_path.endswith(".safetensors")

    if not os.path.isdir(lora_path):
        return False

    # Prefer the canonical filename, then fall back to any lora safetensors.
    if os.path.exists(os.path.join(lora_path, "pytorch_lora_weights.safetensors")):
        return True
    return any(
        entry.endswith(".safetensors") and "lora" in entry.lower()
        for entry in os.listdir(lora_path)
    )
|
|
|
|
| def load_lora_from_checkpoint(pipeline, checkpoint_path, lora_rank=64): |
| """ |
| 从accelerator checkpoint目录加载LoRA权重或完整模型权重 |
| 参考 sample_sd3_rectified_ddp.py 中的实现 |
| """ |
| logger.info(f"Loading weights from accelerator checkpoint: {checkpoint_path}") |
| |
| try: |
| from safetensors.torch import load_file |
| model_file = os.path.join(checkpoint_path, "model.safetensors") |
| if not os.path.exists(model_file): |
| logger.error(f"Model file not found: {model_file}") |
| return False |
| |
| |
| state_dict = load_file(model_file) |
| all_keys = list(state_dict.keys()) |
| |
| |
| lora_keys = [k for k in all_keys if 'lora' in k.lower() and 'transformer' in k.lower()] |
| base_layer_keys = [k for k in all_keys if 'base_layer' in k.lower() and 'transformer' in k.lower()] |
| non_lora_transformer_keys = [k for k in all_keys if 'lora' not in k.lower() and 'base_layer' not in k.lower() and 'transformer' in k.lower()] |
| |
| logger.info(f"Checkpoint analysis: Total={len(all_keys)}, LoRA={len(lora_keys)}, BaseLayer={len(base_layer_keys)}, Merged={len(non_lora_transformer_keys)}") |
| |
| |
| if len(base_layer_keys) > 0: |
| logger.info("Detected PEFT format (base_layer + LoRA), merging weights...") |
| |
| |
| merged_state_dict = {} |
| modules_to_merge = {} |
| non_lora_keys_found = [] |
| |
| for key in all_keys: |
| new_key = key |
| has_transformer_prefix = False |
| |
| if key.startswith('base_model.model.transformer.'): |
| new_key = key[len('base_model.model.transformer.'):] |
| has_transformer_prefix = True |
| elif key.startswith('model.transformer.'): |
| new_key = key[len('model.transformer.'):] |
| has_transformer_prefix = True |
| elif key.startswith('transformer.'): |
| new_key = key[len('transformer.'):] |
| has_transformer_prefix = True |
| elif 'transformer' in key.lower(): |
| has_transformer_prefix = True |
| |
| if not has_transformer_prefix: |
| continue |
| |
| |
| if '.base_layer.weight' in new_key: |
| module_key = new_key.replace('.base_layer.weight', '.weight') |
| if module_key not in modules_to_merge: |
| modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None} |
| modules_to_merge[module_key]['base_weight'] = (key, state_dict[key]) |
| elif '.base_layer.bias' in new_key: |
| module_key = new_key.replace('.base_layer.bias', '.bias') |
| if module_key not in modules_to_merge: |
| modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None} |
| modules_to_merge[module_key]['base_bias'] = (key, state_dict[key]) |
| elif '.lora_A.default.weight' in new_key: |
| module_key = new_key.replace('.lora_A.default.weight', '.weight') |
| if module_key not in modules_to_merge: |
| modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None} |
| modules_to_merge[module_key]['lora_A'] = (key, state_dict[key]) |
| elif '.lora_B.default.weight' in new_key: |
| module_key = new_key.replace('.lora_B.default.weight', '.weight') |
| if module_key not in modules_to_merge: |
| modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None} |
| modules_to_merge[module_key]['lora_B'] = (key, state_dict[key]) |
| elif 'lora' not in new_key.lower() and 'base_layer' not in new_key.lower(): |
| merged_state_dict[new_key] = state_dict[key] |
| non_lora_keys_found.append(new_key) |
| |
| logger.info(f"Found {len(non_lora_keys_found)} non-LoRA transformer keys") |
| logger.info(f"Merging {len(modules_to_merge)} modules...") |
| |
| |
| for module_key, weights in modules_to_merge.items(): |
| if weights['base_weight'] is not None: |
| base_key, base_weight = weights['base_weight'] |
| base_weight = base_weight.clone() |
| |
| if weights['lora_A'] is not None and weights['lora_B'] is not None: |
| lora_A_key, lora_A = weights['lora_A'] |
| lora_B_key, lora_B = weights['lora_B'] |
| |
| rank_value = lora_A.shape[0] |
| alpha = rank_value |
| |
| lora_delta = torch.matmul(lora_B, lora_A) |
| |
| if lora_delta.shape == base_weight.shape: |
| merged_weight = base_weight + lora_delta * (alpha / rank_value) |
| merged_state_dict[module_key] = merged_weight |
| else: |
| logger.warning(f"Shape mismatch for {module_key}: base={base_weight.shape}, lora_delta={lora_delta.shape}, using base only") |
| merged_state_dict[module_key] = base_weight |
| else: |
| merged_state_dict[module_key] = base_weight |
| |
| if '.bias' in module_key and weights['base_bias'] is not None: |
| bias_key, base_bias = weights['base_bias'] |
| merged_state_dict[module_key] = base_bias.clone() |
| |
| logger.info(f"Merged {len(merged_state_dict)} weights") |
| |
| |
| try: |
| missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(merged_state_dict, strict=False) |
| |
| logger.info(f"Loaded merged weights: Missing={len(missing_keys)}, Unexpected={len(unexpected_keys)}") |
| if missing_keys: |
| critical_keys = ['pos_embed', 'time_text_embed', 'context_embedder', 'norm_out', 'proj_out'] |
| has_critical = any(any(ck in mk for ck in critical_keys) for mk in missing_keys) |
| if has_critical: |
| logger.warning("Missing critical keys! These will use pretrained model values.") |
| |
| logger.info("Successfully loaded merged model weights") |
| return True |
| |
| except Exception as e: |
| logger.error(f"Error loading merged weights: {e}") |
| import traceback |
| traceback.print_exc() |
| return False |
| |
| |
| elif len(non_lora_transformer_keys) > 0: |
| logger.info("Detected merged model weights (contains full transformer weights)") |
| |
| transformer_state_dict = {} |
| for key, value in state_dict.items(): |
| new_key = key |
| if key.startswith('base_model.model.transformer.'): |
| new_key = key[len('base_model.model.transformer.'):] |
| elif key.startswith('model.transformer.'): |
| new_key = key[len('model.transformer.'):] |
| elif key.startswith('transformer.'): |
| new_key = key[len('transformer.'):] |
| |
| if (new_key.startswith('transformer_blocks') or |
| new_key.startswith('pos_embed') or |
| new_key.startswith('time_text_embed') or |
| 'lora' in new_key.lower()): |
| transformer_state_dict[new_key] = value |
| |
| logger.info(f"Extracted {len(transformer_state_dict)} transformer weight keys") |
| |
| try: |
| missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(transformer_state_dict, strict=False) |
| |
| logger.info(f"Loaded full model weights: Missing={len(missing_keys)}, Unexpected={len(unexpected_keys)}") |
| |
| if len(missing_keys) > len(transformer_state_dict) * 0.5: |
| logger.warning("Too many missing keys, weights may not be fully loaded") |
| return False |
| |
| logger.info("Successfully loaded merged model weights") |
| return True |
| |
| except Exception as e: |
| logger.error(f"Error loading full model weights: {e}") |
| import traceback |
| traceback.print_exc() |
| return False |
| |
| |
| logger.info("Detected LoRA-only weights, loading as LoRA adapter...") |
| |
| |
| detected_rank = None |
| for key, value in state_dict.items(): |
| if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2: |
| detected_rank = value.shape[0] |
| logger.info(f"Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})") |
| break |
| |
| actual_rank = detected_rank if detected_rank is not None else lora_rank |
| if detected_rank is not None and detected_rank != lora_rank: |
| logger.warning(f"Detected rank ({detected_rank}) differs from requested rank ({lora_rank}), using detected rank") |
| |
| |
| if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config: |
| if "default" in pipeline.transformer.peft_config: |
| logger.info("Removing existing 'default' adapter before adding new one...") |
| try: |
| pipeline.unload_lora_weights() |
| logger.info("Successfully unloaded existing LoRA adapter") |
| except Exception as e: |
| logger.error(f"Could not unload existing adapter: {e}") |
| return False |
| |
| |
| transformer_lora_config = LoraConfig( |
| r=actual_rank, |
| lora_alpha=actual_rank, |
| init_lora_weights="gaussian", |
| target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"], |
| ) |
| |
| pipeline.transformer.add_adapter(transformer_lora_config) |
| logger.info(f"LoRA adapter configured with rank={actual_rank}") |
| |
| |
| lora_state_dict = {} |
| for key, value in state_dict.items(): |
| if 'lora' in key.lower() and 'transformer' in key.lower(): |
| new_key = key |
| if key.startswith('base_model.model.transformer.'): |
| new_key = key[len('base_model.model.transformer.'):] |
| elif key.startswith('model.transformer.'): |
| new_key = key[len('model.transformer.'):] |
| elif key.startswith('transformer.'): |
| if not key[len('transformer.'):].startswith('transformer_blocks'): |
| new_key = key[len('transformer.'):] |
| else: |
| new_key = key[len('transformer.'):] |
| |
| if 'transformer_blocks' in new_key or 'transformer' in new_key: |
| lora_state_dict[new_key] = value |
| |
| if not lora_state_dict: |
| logger.error("No LoRA weights found in checkpoint") |
| return False |
| |
| logger.info(f"Found {len(lora_state_dict)} LoRA weight keys") |
| |
| |
| converted_dict = {} |
| for key, value in lora_state_dict.items(): |
| new_key = key |
| if '.default.weight' in new_key: |
| new_key = new_key.replace('.default.weight', '.weight') |
| elif '.default.bias' in new_key: |
| new_key = new_key.replace('.default.bias', '.bias') |
| elif '.default' in new_key and (new_key.endswith('.weight') or new_key.endswith('.bias')): |
| new_key = new_key.replace('.default', '') |
| converted_dict[new_key] = value |
| |
| logger.info(f"Converted {len(converted_dict)} keys (removed .default suffix if present)") |
| |
| |
| try: |
| incompatible_keys = set_peft_model_state_dict( |
| pipeline.transformer, |
| converted_dict, |
| adapter_name="default" |
| ) |
| |
| if incompatible_keys is not None: |
| missing_keys = getattr(incompatible_keys, "missing_keys", []) |
| unexpected_keys = getattr(incompatible_keys, "unexpected_keys", []) |
| |
| logger.info(f"LoRA loading result: Missing={len(missing_keys)}, Unexpected={len(unexpected_keys)}") |
| |
| if len(missing_keys) > len(converted_dict) * 0.5: |
| logger.error("Too many missing keys, LoRA weights not loaded correctly!") |
| return False |
| else: |
| logger.info("LoRA weights loaded (no incompatible keys reported)") |
| |
| except RuntimeError as e: |
| error_str = str(e) |
| if "size mismatch" in error_str: |
| logger.error(f"Size mismatch error: The checkpoint rank doesn't match the adapter rank") |
| import re |
| match = re.search(r'copying a param with shape torch\.Size\(\[(\d+),', error_str) |
| if match: |
| checkpoint_rank = int(match.group(1)) |
| logger.error(f"Detected checkpoint rank: {checkpoint_rank}, Adapter was configured with rank: {actual_rank}") |
| else: |
| logger.error(f"Error setting LoRA state dict: {e}") |
| import traceback |
| traceback.print_exc() |
| try: |
| pipeline.unload_lora_weights() |
| except: |
| pass |
| return False |
| except Exception as e: |
| logger.error(f"Error setting LoRA state dict: {e}") |
| import traceback |
| traceback.print_exc() |
| try: |
| pipeline.unload_lora_weights() |
| except: |
| pass |
| return False |
| |
| |
| pipeline.transformer.set_adapter("default") |
| logger.info("Successfully loaded and verified LoRA weights from checkpoint") |
| |
| return True |
| |
| except Exception as e: |
| logger.error(f"Error loading LoRA from checkpoint: {e}") |
| import traceback |
| traceback.print_exc() |
| return False |
|
|
|
|
def manage_checkpoints(output_dir, checkpoints_total_limit):
    """
    Prune old checkpoints so at most ``checkpoints_total_limit`` remain.

    Scans ``output_dir`` for ``checkpoint-<step>`` directories and removes the
    oldest ones once the count reaches the limit, freeing one extra slot for
    the checkpoint about to be written.

    Args:
        output_dir: Directory containing the ``checkpoint-*`` folders.
        checkpoints_total_limit: Maximum number of checkpoints to keep, or
            ``None`` to disable pruning.

    Returns:
        tuple: (number of checkpoints removed, number of checkpoints remaining)
    """
    if checkpoints_total_limit is None:
        return 0, 0

    try:
        if not os.path.exists(output_dir):
            return 0, 0

        # Collect (step, dirname) pairs for every valid checkpoint directory.
        found = []
        for entry in os.listdir(output_dir):
            full_path = os.path.join(output_dir, entry)
            if not (entry.startswith("checkpoint-") and os.path.isdir(full_path)):
                continue
            try:
                found.append((int(entry.split("-")[1]), entry))
            except (ValueError, IndexError):
                continue

        found.sort(key=lambda pair: pair[0])
        total = len(found)

        if total < checkpoints_total_limit:
            return 0, total

        # Delete enough of the oldest checkpoints to leave room for one more.
        surplus = total - checkpoints_total_limit + 1
        deleted = 0
        for _step, name in found[:surplus]:
            target = os.path.join(output_dir, name)
            try:
                if os.path.exists(target):
                    shutil.rmtree(target)
                    deleted += 1
            except Exception:
                # Best-effort cleanup: a failed delete is simply skipped.
                pass

        return deleted, total - deleted

    except Exception:
        return 0, 0
|
|
|
|
def save_sit_weights(model, save_path):
    """
    Persist the SIT block weights and their configuration to disk.

    Writes the ``rectified_noise_module`` parameters as a safetensors file
    when the library is available (falling back to ``torch.save``), plus a
    ``sit_config.json`` describing the module architecture.

    Args:
        model: Integrated model that owns the ``rectified_noise_module``.
        save_path: Destination directory (created if missing).

    Raises:
        Exception: Re-raises any error encountered while saving.
    """
    try:
        os.makedirs(save_path, exist_ok=True)

        # Strip DDP / torch.compile wrappers to reach the underlying module.
        if hasattr(model, 'module'):
            base = model.module
        elif hasattr(model, '_orig_mod'):
            base = model._orig_mod
        else:
            base = model

        sit_module = base.rectified_noise_module

        # Detach every SIT parameter onto the CPU before serialization.
        sit_state_dict = {
            name: param.cpu().clone()
            for name, param in sit_module.named_parameters()
        }

        # Prefer safetensors; fall back to a plain torch pickle if unavailable.
        try:
            from safetensors.torch import save_file
            weights_path = os.path.join(save_path, "pytorch_sit_weights.safetensors")
            save_file(sit_state_dict, weights_path)
            logger.info(f"保存 SIT 权重: {weights_path}")
        except ImportError:
            weights_path = os.path.join(save_path, "pytorch_sit_weights.bin")
            torch.save(sit_state_dict, weights_path)
            logger.info(f"保存 SIT 权重: {weights_path}")

        # Record the architecture so the weights can be reloaded later.
        config = {
            "num_sit_layers": sit_module.num_sit_layers,
            "hidden_size": sit_module.hidden_size,
            "input_dim": sit_module.input_dim,
            "num_attention_heads": 16,
            "intermediate_size": sit_module.hidden_size * 4,
            "model_type": "rectified_noise",
            "architecture": "SIT",
            "version": "1.0"
        }

        config_path = os.path.join(save_path, "sit_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        logger.info(f"SIT 权重和配置保存到: {save_path}")

    except Exception as e:
        logger.error(f"保存 SIT 权重时出错: {e}")
        raise
|
|
|
|
def tokenize_prompt(tokenizer, prompt):
    """
    Tokenize a prompt with CLIP-style settings (77 tokens, padded/truncated).

    Args:
        tokenizer: The tokenizer instance to apply.
        prompt: Prompt text (string or list of strings).

    Returns:
        Tensor of token IDs produced by the tokenizer.
    """
    return tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    ).input_ids
|
|
|
|
| def _encode_prompt_with_t5(text_encoder, tokenizer, max_sequence_length, prompt=None, |
| num_images_per_prompt=1, device=None, text_input_ids=None): |
| """ |
| 使用 T5 文本编码器编码提示词 |
| |
| Args: |
| text_encoder: T5 文本编码器 |
| tokenizer: T5 分词器 |
| max_sequence_length: 最大序列长度 |
| prompt: 输入提示词 |
| num_images_per_prompt: 每个提示词生成的图像数量 |
| device: 计算设备 |
| text_input_ids: 预分词的 token IDs |
| |
| Returns: |
| 编码后的提示词嵌入 |
| """ |
| if prompt is not None: |
| prompt = [prompt] if isinstance(prompt, str) else prompt |
| batch_size = len(prompt) |
| else: |
| if text_input_ids is None: |
| raise ValueError("Either prompt or text_input_ids must be provided") |
| batch_size = text_input_ids.shape[0] |
|
|
| if tokenizer is not None and prompt is not None: |
| text_inputs = tokenizer( |
| prompt, |
| padding="max_length", |
| max_length=256, |
| truncation=True, |
| add_special_tokens=True, |
| return_tensors="pt", |
| ) |
| text_input_ids = text_inputs.input_ids |
| else: |
| if text_input_ids is None: |
| raise ValueError("text_input_ids must be provided") |
|
|
| prompt_embeds = text_encoder(text_input_ids.to(device))[0] |
| dtype = text_encoder.dtype |
| prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) |
|
|
| _, seq_len, _ = prompt_embeds.shape |
| prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) |
| prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) |
|
|
| return prompt_embeds |
|
|
|
|
| def _encode_prompt_with_clip(text_encoder, tokenizer, prompt, device=None, |
| text_input_ids=None, num_images_per_prompt: int = 1): |
| """ |
| 使用 CLIP 文本编码器编码提示词 |
| |
| Args: |
| text_encoder: CLIP 文本编码器 |
| tokenizer: CLIP 分词器 |
| prompt: 输入提示词 |
| device: 计算设备 |
| text_input_ids: 预分词的 token IDs |
| num_images_per_prompt: 每个提示词生成的图像数量 |
| |
| Returns: |
| tuple: (prompt_embeds, pooled_prompt_embeds) |
| """ |
| if prompt is not None: |
| prompt = [prompt] if isinstance(prompt, str) else prompt |
| batch_size = len(prompt) |
| else: |
| if text_input_ids is None: |
| raise ValueError("Either prompt or text_input_ids must be provided") |
| batch_size = text_input_ids.shape[0] |
|
|
| if tokenizer is not None and prompt is not None: |
| text_inputs = tokenizer( |
| prompt, |
| padding="max_length", |
| max_length=77, |
| truncation=True, |
| return_tensors="pt", |
| ) |
| text_input_ids = text_inputs.input_ids |
| else: |
| if text_input_ids is None: |
| raise ValueError("text_input_ids must be provided") |
|
|
| prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) |
| pooled_prompt_embeds = prompt_embeds[0] |
| prompt_embeds = prompt_embeds.hidden_states[-2] |
| prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) |
|
|
| _, seq_len, _ = prompt_embeds.shape |
| prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) |
| prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) |
|
|
| return prompt_embeds, pooled_prompt_embeds |
|
|
|
|
def encode_prompt(text_encoders, tokenizers, prompt, max_sequence_length, device=None,
                 num_images_per_prompt: int = 1, text_input_ids_list=None):
    """
    Encode a prompt with all three SD3 text encoders (two CLIPs plus T5).

    The two CLIP embeddings are concatenated along the feature axis, padded to
    the T5 feature width, and then concatenated with the T5 embedding along
    the sequence axis.

    Args:
        text_encoders: List of the three text encoders [clip_1, clip_2, t5].
        tokenizers: Matching list of tokenizers.
        prompt: Prompt text.
        max_sequence_length: Maximum T5 sequence length.
        device: Compute device; defaults to each encoder's own device.
        num_images_per_prompt: Number of images generated per prompt.
        text_input_ids_list: Optional pre-tokenized IDs, one entry per encoder.

    Returns:
        tuple: (prompt_embeds, pooled_prompt_embeds)
    """
    if prompt is not None:
        prompt = [prompt] if isinstance(prompt, str) else list(prompt)

    # Run both CLIP encoders and collect sequence + pooled embeddings.
    clip_embeds = []
    clip_pooled = []
    for idx, (clip_tok, clip_enc) in enumerate(zip(tokenizers[:2], text_encoders[:2])):
        embeds, pooled = _encode_prompt_with_clip(
            text_encoder=clip_enc,
            tokenizer=clip_tok,
            prompt=prompt,
            device=device if device is not None else clip_enc.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[idx] if text_input_ids_list else None,
        )
        clip_embeds.append(embeds)
        clip_pooled.append(pooled)

    clip_prompt_embeds = torch.cat(clip_embeds, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled, dim=-1)

    # T5 provides the long-context sequence embedding.
    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
        device=device if device is not None else text_encoders[-1].device,
    )

    # Pad the CLIP features up to the T5 feature width before joining tokens.
    pad_amount = t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]
    clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, pad_amount))

    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds
|
|
|
|
def compute_loss(model_pred, target, mean_output, var_output, args, sigmas, timesteps=None):
    """
    Compute the combined training loss: weighted reconstruction + KL penalty.

    Args:
        model_pred: Model-predicted noise tensor (assumes 4D B,C,H,W when
            ``timesteps`` is given — see the broadcast below).
        target: Target noise tensor, same shape as ``model_pred``.
        mean_output: Mean branch of the dual-output head.
            NOTE(review): currently unused in this function — confirm whether
            it was meant to enter the KL term.
        var_output: Variance branch of the dual-output head (clamped positive
            before the log).
        args: Training arguments; reads ``weighting_scheme``,
            ``time_weight_alpha`` and ``kl_loss_weight``.
        sigmas: Flow-matching noise scales used for the SD3 loss weighting.
        timesteps: Optional timesteps; when given, an exponential time-decay
            weight is applied on top of the SD3 weighting.

    Returns:
        tuple: (total_loss, reconstruction_loss, kl_loss)
    """
    # Per-sample weighting from the SD3 recipe (depends on sigmas).
    weighting = compute_loss_weighting_for_sd3(
        weighting_scheme=args.weighting_scheme, sigmas=sigmas
    )

    if timesteps is not None:
        # Normalize timesteps into [0, 1]; assumes a 1000-step schedule —
        # TODO confirm this matches the scheduler's num_train_timesteps.
        max_timestep = 1000.0
        normalized_t = timesteps.float() / max_timestep
        normalized_t = torch.clamp(normalized_t, 0.0, 1.0)

        # Exponential time decay exp(-alpha * t): larger weight at small t
        # when alpha > 0.
        alpha = args.time_weight_alpha
        time_weight = torch.exp(-alpha * normalized_t)

        if time_weight.dim() == 1:
            # Reshape per-sample weights to broadcast over (C, H, W).
            time_weight = time_weight.view(-1, 1, 1, 1)

        combined_weighting = weighting.float() * time_weight
    else:
        combined_weighting = weighting.float()

    # Weighted MSE: mean over all elements per sample, then over the batch.
    reconstruction_loss = torch.mean(
        (combined_weighting * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
        1,
    ).mean()

    # Regularizer on the variance branch only.
    # NOTE(review): 0.5 * (exp(log_var) + log_var - 1) is not the standard
    # KL(N(mu, var) || N(0, 1)) form (which uses -log_var and a mean term);
    # confirm this is intentional.
    log_var = torch.log(torch.clamp(var_output, min=1e-8))
    kl_loss = torch.mean(0.5 * (log_var.exp() + log_var - 1.0))

    # Total objective: reconstruction plus the weighted KL penalty.
    total_loss = reconstruction_loss + args.kl_loss_weight * kl_loss

    return total_loss, reconstruction_loss, kl_loss
|
|
|
|
| def main(args): |
| """ |
| 主训练函数 |
| |
| Args: |
| args: 解析后的训练参数 |
| """ |
| |
| logging_dir = str(Path(args.output_dir, args.logging_dir)) |
|
|
| |
| if torch.backends.mps.is_available() and args.mixed_precision == "bf16": |
| raise ValueError("MPS不支持bfloat16混合精度") |
|
|
| |
| accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) |
| kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
| accelerator = Accelerator( |
| gradient_accumulation_steps=args.gradient_accumulation_steps, |
| mixed_precision=args.mixed_precision, |
| log_with=args.report_to, |
| project_config=accelerator_project_config, |
| kwargs_handlers=[kwargs], |
| ) |
|
|
| |
| if torch.cuda.is_available(): |
| |
| local_rank = accelerator.local_process_index |
| num_gpus = torch.cuda.device_count() |
| |
| if local_rank < num_gpus: |
| device_id = local_rank |
| torch.cuda.set_device(device_id) |
| |
| actual_device = torch.cuda.current_device() |
| if actual_device != device_id: |
| logger.warning(f"[Process {accelerator.process_index}] Warning: Requested device {device_id} but got {actual_device}") |
| else: |
| logger.warning(f"[Process {accelerator.process_index}] Warning: local_rank {local_rank} >= num_gpus {num_gpus}, using device 0") |
| torch.cuda.set_device(0) |
|
|
| |
| logging.basicConfig( |
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| datefmt="%m/%d/%Y %H:%M:%S", |
| level=logging.INFO, |
| ) |
| logger.info(accelerator.state, main_process_only=False) |
| |
| |
| logger.info(f"[Process {accelerator.process_index}] Accelerator initialized:") |
| logger.info(f" - Process index: {accelerator.process_index}") |
| logger.info(f" - Local process index: {accelerator.local_process_index}") |
| logger.info(f" - Device: {accelerator.device}") |
| logger.info(f" - Num processes: {accelerator.num_processes}") |
| logger.info(f" - Is main process: {accelerator.is_main_process}") |
| logger.info(f" - Is local main process: {accelerator.is_local_main_process}") |
| if torch.cuda.is_available(): |
| logger.info(f" - CUDA device count: {torch.cuda.device_count()}") |
| logger.info(f" - Current CUDA device: {torch.cuda.current_device()}") |
| logger.info(f" - CUDA device name: {torch.cuda.get_device_name(torch.cuda.current_device())}") |
| |
| if hasattr(accelerator.device, 'index'): |
| logger.info(f" - Accelerator device index: {accelerator.device.index}") |
| logger.info(f" - Expected device for local_rank {accelerator.local_process_index}: cuda:{accelerator.local_process_index}") |
| |
| if accelerator.is_local_main_process: |
| datasets.utils.logging.set_verbosity_warning() |
| transformers.utils.logging.set_verbosity_warning() |
| diffusers.utils.logging.set_verbosity_info() |
| else: |
| datasets.utils.logging.set_verbosity_error() |
| transformers.utils.logging.set_verbosity_error() |
| diffusers.utils.logging.set_verbosity_error() |
|
|
| |
| if args.seed is not None: |
| set_seed(args.seed) |
|
|
| |
| if accelerator.is_main_process: |
| if args.output_dir is not None: |
| os.makedirs(args.output_dir, exist_ok=True) |
|
|
| |
| tokenizer_one = CLIPTokenizer.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="tokenizer", |
| revision=args.revision, |
| ) |
| tokenizer_two = CLIPTokenizer.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="tokenizer_2", |
| revision=args.revision, |
| ) |
| tokenizer_three = T5TokenizerFast.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="tokenizer_3", |
| revision=args.revision, |
| ) |
|
|
| |
| text_encoder_cls_one = import_model_class_from_model_name_or_path( |
| args.pretrained_model_name_or_path, args.revision |
| ) |
| text_encoder_cls_two = import_model_class_from_model_name_or_path( |
| args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" |
| ) |
| text_encoder_cls_three = import_model_class_from_model_name_or_path( |
| args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" |
| ) |
|
|
| |
| |
| noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( |
| args.pretrained_model_name_or_path, subfolder="scheduler" |
| ) |
| noise_scheduler_copy = copy.deepcopy(noise_scheduler) |
| |
| logger.info("Loading models with memory optimization (CPU first)...") |
| text_encoder_one = text_encoder_cls_one.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="text_encoder", |
| revision=args.revision, |
| variant=args.variant, |
| low_cpu_mem_usage=True, |
| torch_dtype=torch.float32, |
| ) |
| text_encoder_two = text_encoder_cls_two.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="text_encoder_2", |
| revision=args.revision, |
| variant=args.variant, |
| low_cpu_mem_usage=True, |
| torch_dtype=torch.float32, |
| ) |
| text_encoder_three = text_encoder_cls_three.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="text_encoder_3", |
| revision=args.revision, |
| variant=args.variant, |
| low_cpu_mem_usage=True, |
| torch_dtype=torch.float32, |
| ) |
| |
| vae = AutoencoderKL.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="vae", |
| revision=args.revision, |
| variant=args.variant, |
| low_cpu_mem_usage=True, |
| torch_dtype=torch.float32, |
| ) |
| |
| transformer = SD3Transformer2DModel.from_pretrained( |
| args.pretrained_model_name_or_path, |
| subfolder="transformer", |
| revision=args.revision, |
| variant=args.variant, |
| low_cpu_mem_usage=True, |
| torch_dtype=torch.float32, |
| ) |
| vae.requires_grad_(False) |
| text_encoder_one.requires_grad_(False) |
| text_encoder_two.requires_grad_(False) |
| text_encoder_three.requires_grad_(False) |
| |
| |
| logger.info(f"Loading LoRA weights from: {args.lora_model_path}") |
| |
| |
| |
| logger.info("Creating temporary pipeline for LoRA loading (with memory optimization)...") |
| temp_pipeline = StableDiffusion3Pipeline.from_pretrained( |
| args.pretrained_model_name_or_path, |
| vae=vae, |
| text_encoder=text_encoder_one, |
| text_encoder_2=text_encoder_two, |
| text_encoder_3=text_encoder_three, |
| transformer=transformer, |
| revision=args.revision, |
| variant=args.variant, |
| torch_dtype=torch.float32, |
| low_cpu_mem_usage=True, |
| ) |
| |
| |
| lora_loaded = False |
| if check_lora_weights_exist(args.lora_model_path): |
| |
| logger.info("Found standard LoRA weights format, loading...") |
| try: |
| temp_pipeline.load_lora_weights(args.lora_model_path) |
| lora_loaded = True |
| logger.info("Successfully loaded LoRA weights from standard format") |
| except Exception as e: |
| logger.warning(f"Failed to load LoRA from standard format: {e}") |
| logger.info("Trying accelerator checkpoint format...") |
| |
| |
| if not lora_loaded and os.path.isdir(args.lora_model_path): |
| logger.info("Trying to load from accelerator checkpoint format...") |
| |
| detected_rank = None |
| try: |
| from safetensors.torch import load_file |
| model_file = os.path.join(args.lora_model_path, "model.safetensors") |
| if os.path.exists(model_file): |
| state_dict = load_file(model_file) |
| for key, value in state_dict.items(): |
| if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2: |
| detected_rank = value.shape[0] |
| logger.info(f"Detected LoRA rank from checkpoint: {detected_rank}") |
| break |
| except Exception as e: |
| logger.warning(f"Could not detect rank from checkpoint: {e}") |
| |
| |
| lora_rank = detected_rank if detected_rank is not None else 64 |
| lora_loaded = load_lora_from_checkpoint(temp_pipeline, args.lora_model_path, lora_rank=lora_rank) |
| |
| if not lora_loaded: |
| logger.error("Failed to load LoRA from accelerator checkpoint format") |
| raise ValueError("Could not load LoRA weights from the provided path") |
| |
| if not lora_loaded: |
| logger.error("Failed to load LoRA weights. Please check the path and format.") |
| raise ValueError("Invalid LoRA checkpoint.") |
| |
| |
| |
| if hasattr(temp_pipeline.transformer, 'peft_config') and temp_pipeline.transformer.peft_config: |
| logger.info("Fusing LoRA weights into base model...") |
| temp_pipeline.fuse_lora(adapter_names=["default"], lora_scale=1.0) |
| temp_pipeline.unload_lora_weights() |
| logger.info("LoRA weights fused successfully") |
| else: |
| logger.info("LoRA weights already merged (no fusion needed)") |
|
|
| |
| text_encoder_one = temp_pipeline.text_encoder |
| text_encoder_two = temp_pipeline.text_encoder_2 |
| text_encoder_three = temp_pipeline.text_encoder_3 |
| transformer = temp_pipeline.transformer |
|
|
| |
| logger.info("Deleting temporary pipeline to free memory...") |
| del temp_pipeline |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| import gc |
| gc.collect() |
| logger.info("LoRA weights loaded successfully") |
|
|
| |
| input_dim = transformer.config.in_channels |
| |
| if hasattr(transformer.config, 'joint_attention_dim') and transformer.config.joint_attention_dim is not None: |
| sit_hidden_size = transformer.config.joint_attention_dim |
| elif hasattr(transformer.config, 'inner_dim') and transformer.config.inner_dim is not None: |
| sit_hidden_size = transformer.config.inner_dim |
| elif hasattr(transformer.config, 'hidden_size') and transformer.config.hidden_size is not None: |
| sit_hidden_size = transformer.config.hidden_size |
| else: |
| sit_hidden_size = 1536 |
| |
| if hasattr(transformer.config, 'hidden_size') and transformer.config.hidden_size is not None: |
| transformer_hidden_size = transformer.config.hidden_size |
| else: |
| transformer_hidden_size = 1536 |
| |
| if hasattr(transformer.config, 'num_attention_heads'): |
| num_attention_heads = transformer.config.num_attention_heads |
| else: |
| num_attention_heads = 32 |
| |
| rectified_noise_module = RectifiedNoiseModule( |
| hidden_size=sit_hidden_size, |
| num_sit_layers=args.num_sit_layers, |
| num_attention_heads=num_attention_heads, |
| input_dim=input_dim, |
| transformer_hidden_size=transformer_hidden_size |
| ) |
| |
| |
| model = SD3WithRectifiedNoise(transformer, rectified_noise_module) |
| |
| logger.info(f"Created rectified noise model with {args.num_sit_layers} SIT layers") |
|
|
| |
| vae.requires_grad_(False) |
| text_encoder_one.requires_grad_(False) |
| text_encoder_two.requires_grad_(False) |
| text_encoder_three.requires_grad_(False) |
| |
| for param in model.transformer.parameters(): |
| param.requires_grad = False |
| |
| |
| unwrapped_model_temp = model |
| if hasattr(model, 'module'): |
| unwrapped_model_temp = model.module |
| for param in unwrapped_model_temp.rectified_noise_module.parameters(): |
| param.requires_grad = True |
| |
| |
| total_params = sum(p.numel() for p in model.parameters()) |
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| trainable_ratio = (trainable_params / total_params) * 100 if total_params > 0 else 0 |
| |
| logger.info("Model Parameter Statistics:") |
| logger.info(f" Total parameters: {total_params:,}") |
| logger.info(f" Trainable parameters: {trainable_params:,}") |
| logger.info(f" Trainable ratio: {trainable_ratio:.2f}%") |
| |
| |
| transformer_params = sum(p.numel() for p in model.transformer.parameters()) |
| rectified_noise_params = sum(p.numel() for p in model.rectified_noise_module.parameters()) |
| |
| logger.info("Component Parameter Statistics:") |
| logger.info(f" SD3 Transformer parameters: {transformer_params:,}") |
| logger.info(f" Rectified Noise module parameters: {rectified_noise_params:,}") |
| |
| |
| if hasattr(model.rectified_noise_module, 'input_projection') and model.rectified_noise_module.input_projection is not None: |
| input_proj_params = sum(p.numel() for p in model.rectified_noise_module.input_projection.parameters()) |
| logger.info(f" Input projection parameters: {input_proj_params:,}") |
| |
| sit_blocks_params = sum(p.numel() for p in model.rectified_noise_module.sit_blocks.parameters()) |
| logger.info(f" SIT blocks parameters: {sit_blocks_params:,}") |
| |
| dual_output_params = sum(p.numel() for p in model.rectified_noise_module.dual_output_layer.parameters()) |
| logger.info(f" Dual output layer parameters: {dual_output_params:,}") |
| |
| output_proj_params = sum(p.numel() for p in model.rectified_noise_module.output_projection.parameters()) |
| logger.info(f" Output projection parameters: {output_proj_params:,}") |
|
|
| logger.info("Frozen all non-SIT parameters") |
|
|
| |
| weight_dtype = torch.float32 |
| if accelerator.mixed_precision == "fp16": |
| weight_dtype = torch.float16 |
| elif accelerator.mixed_precision == "bf16": |
| weight_dtype = torch.bfloat16 |
|
|
| |
| |
| if torch.cuda.is_available(): |
| |
| target_device_id = accelerator.local_process_index % torch.cuda.device_count() |
| target_device = torch.device(f"cuda:{target_device_id}") |
| torch.cuda.set_device(target_device_id) |
| else: |
| target_device = accelerator.device |
| |
| logger.info(f"[Process {accelerator.process_index}] Moving models to device: {target_device}") |
| logger.info(f"[Process {accelerator.process_index}] Accelerator state: num_processes={accelerator.num_processes}, local_process_index={accelerator.local_process_index}") |
| logger.info(f"[Process {accelerator.process_index}] Target device: {target_device}, Current CUDA device: {torch.cuda.current_device() if torch.cuda.is_available() else 'N/A'}") |
| |
| |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| import gc |
| gc.collect() |
| |
| |
| logger.info(f"[Process {accelerator.process_index}] Moving VAE to device...") |
| vae.to(target_device, dtype=torch.float32) |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| |
| logger.info(f"[Process {accelerator.process_index}] Moving text encoders to device...") |
| text_encoder_one.to(target_device, dtype=weight_dtype) |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| |
| text_encoder_two.to(target_device, dtype=weight_dtype) |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| |
| text_encoder_three.to(target_device, dtype=weight_dtype) |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| |
| logger.info(f"[Process {accelerator.process_index}] Moving main model to device...") |
| model.to(target_device, dtype=weight_dtype) |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| import gc |
| gc.collect() |
| |
| |
| logger.info(f"[Process {accelerator.process_index}] Device verification:") |
| logger.info(f" - VAE device: {next(vae.parameters()).device}") |
| logger.info(f" - Model device: {next(model.parameters()).device}") |
| logger.info(f" - Text encoder 1 device: {next(text_encoder_one.parameters()).device}") |
| logger.info(f" - Text encoder 2 device: {next(text_encoder_two.parameters()).device}") |
| logger.info(f" - Text encoder 3 device: {next(text_encoder_three.parameters()).device}") |
| logger.info(f" - CUDA visible devices: {torch.cuda.device_count()} devices available") |
| if torch.cuda.is_available(): |
| logger.info(f" - Current CUDA device: {torch.cuda.current_device()}") |
| logger.info(f" - Device name: {torch.cuda.get_device_name(target_device_id)}") |
|
|
| def unwrap_model(model): |
| model = accelerator.unwrap_model(model) |
| model = model._orig_mod if hasattr(model, '_orig_mod') else model |
| return model |
|
|
| |
| if args.allow_tf32 and torch.cuda.is_available(): |
| torch.backends.cuda.matmul.allow_tf32 = True |
|
|
| |
| if args.scale_lr: |
| args.sit_learning_rate = ( |
| args.sit_learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes |
| ) |
|
|
| |
| unwrapped_model = unwrap_model(model) if hasattr(model, 'module') else model |
| sit_parameters = list(filter(lambda p: p.requires_grad, unwrapped_model.rectified_noise_module.parameters())) |
| |
| if args.use_8bit_adam: |
| try: |
| import bitsandbytes as bnb |
| except ImportError: |
| raise ImportError("To use 8-bit Adam, install bitsandbytes: pip install bitsandbytes") |
| optimizer_class = bnb.optim.AdamW8bit |
| else: |
| optimizer_class = torch.optim.AdamW |
|
|
| optimizer = optimizer_class( |
| sit_parameters, |
| lr=args.sit_learning_rate, |
| betas=(args.adam_beta1, args.adam_beta2), |
| weight_decay=args.adam_weight_decay, |
| eps=args.adam_epsilon, |
| ) |
| |
| logger.info(f"Created optimizer for {len(sit_parameters)} SIT parameters") |
|
|
| |
| |
| logger.info("Loading dataset...") |
| |
| if args.train_data_dir is None: |
| raise ValueError("train_data_dir must be provided") |
| |
| if not os.path.exists(args.train_data_dir): |
| raise ValueError(f"Training data directory does not exist: {args.train_data_dir}") |
| |
| with accelerator.main_process_first(): |
| metadata_path = None |
| if args.train_data_dir is not None: |
| |
| potential_metadata = os.path.join(args.train_data_dir, "metadata.jsonl") |
| if os.path.exists(potential_metadata): |
| metadata_path = potential_metadata |
| |
| if metadata_path is not None: |
| |
| if accelerator.is_main_process: |
| logger.info(f"Found metadata.jsonl, using efficient loading method") |
| dataset = load_dataset_from_jsonl(metadata_path, args.train_data_dir, accelerator) |
| elif args.dataset_name is not None: |
| |
| dataset = load_dataset( |
| args.dataset_name, |
| args.dataset_config_name, |
| cache_dir=args.cache_dir, |
| data_dir=args.train_data_dir |
| ) |
| else: |
| |
| if accelerator.is_main_process: |
| logger.warning("No metadata.jsonl found, using imagefolder (may be slow for large datasets)") |
| try: |
| dataset = load_dataset( |
| "imagefolder", |
| data_dir=args.train_data_dir, |
| cache_dir=args.cache_dir, |
| ) |
| except Exception as e: |
| logger.error(f"Failed to load imagefolder dataset: {e}") |
| |
| logger.info("Using data_files fallback...") |
| data_files = {} |
| data_files["train"] = os.path.join(args.train_data_dir, "**") |
| dataset = load_dataset( |
| "imagefolder", |
| data_files=data_files, |
| cache_dir=args.cache_dir, |
| ) |
| |
| if accelerator.is_main_process: |
| logger.info("Dataset loaded successfully.") |
| |
| |
| accelerator.wait_for_everyone() |
| |
| if accelerator.is_main_process: |
| logger.info("All processes synchronized. Building transforms and DataLoader...") |
| |
| |
| train_dataset = None |
| if isinstance(dataset, datasets.DatasetDict): |
| |
| if "train" in dataset: |
| train_dataset = dataset["train"] |
| else: |
| |
| splits = list(dataset.keys()) |
| if splits: |
| train_dataset = dataset[splits[0]] |
| else: |
| raise ValueError("No splits found in dataset") |
| elif isinstance(dataset, datasets.Dataset): |
| |
| train_dataset = dataset |
| else: |
| raise ValueError(f"Unexpected dataset type: {type(dataset)}") |
|
|
| |
# Resolve which dataset columns hold the image and the caption.
# DATASET_NAME_MAPPING maps known hub dataset names to (image_col, caption_col).
column_names = train_dataset.column_names
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)

if accelerator.is_main_process:
    logger.info(f"Dataset columns: {column_names}")

# Image column: honor --image_column when it exists, else fall back to
# "image", then to the mapping / first column.
if args.image_column is not None and args.image_column in column_names:
    image_column = args.image_column
else:
    if 'image' in column_names:
        image_column = 'image'
    else:
        # NOTE(review): assumes dataset_columns (when present) has at least one
        # entry and column_names is non-empty — TODO confirm for custom datasets.
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]

# Warn (main process only) when the user-specified column was silently replaced.
if args.image_column is not None and args.image_column != image_column:
    if accelerator.is_main_process:
        logger.warning(f"Specified image_column '{args.image_column}' not found. Using '{image_column}' instead.")

if accelerator.is_main_process:
    logger.info(f"Using image column: {image_column}")

# Caption column: same strategy — explicit flag, then "text"/"caption",
# then the mapping / second column (or the first if there is only one).
if args.caption_column is not None and args.caption_column in column_names:
    caption_column = args.caption_column
else:
    if 'text' in column_names:
        caption_column = 'text'
    elif 'caption' in column_names:
        caption_column = 'caption'
    else:
        caption_column = dataset_columns[1] if dataset_columns is not None else (column_names[1] if len(column_names) > 1 else column_names[0])

# Warn (main process only) when the user-specified column was silently replaced.
if args.caption_column is not None and args.caption_column != caption_column:
    if accelerator.is_main_process:
        logger.warning(f"Specified caption_column '{args.caption_column}' not found. Using '{caption_column}' instead.")

if accelerator.is_main_process:
    logger.info(f"Using caption column: {caption_column}")
def tokenize_captions(examples, is_train=True):
    """Tokenize the caption column with all three SD3 tokenizers.

    Plain strings are used as-is; list/array entries pick a random caption
    during training and the first caption otherwise. Returns a tuple of the
    three tokenizers' outputs (one per text encoder).
    """
    selected = []
    for entry in examples[caption_column]:
        if isinstance(entry, str):
            selected.append(entry)
        elif isinstance(entry, (list, np.ndarray)):
            selected.append(random.choice(entry) if is_train else entry[0])
        else:
            raise ValueError("Caption column should contain strings or lists of strings.")

    return (
        tokenize_prompt(tokenizer_one, selected),
        tokenize_prompt(tokenizer_two, selected),
        tokenize_prompt(tokenizer_three, selected),
    )
|
|
| |
# Image preprocessing building blocks.
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
# p=1.0 is deliberate: the 50% coin is flipped manually in preprocess_train so
# the flip decision can be made alongside the crop bookkeeping.
train_flip = transforms.RandomHorizontalFlip(p=1.0)
# Final tensor conversion + rescale from [0, 1] to [-1, 1].
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
|
|
def preprocess_train(examples):
    """Batch transform: load/normalize images, resize+flip+crop, tokenize captions.

    Also records per-image original (height, width) and crop top-left offsets,
    which are consumed downstream as conditioning metadata.
    """
    loaded = []
    for item in examples[image_column]:
        if isinstance(item, str):
            # Column stores file paths: open and force 3-channel RGB.
            try:
                item = Image.open(item).convert("RGB")
            except Exception as e:
                # Log only the first few failures on the main process; substitute
                # a black placeholder so the batch shape stays intact.
                if accelerator.is_main_process and len(loaded) < 5:
                    logger.warning(f"Failed to load image {item}: {e}")
                item = Image.new('RGB', (args.resolution, args.resolution), color='black')
        elif hasattr(item, 'convert'):
            # Already a PIL-like object.
            item = item.convert("RGB")
        else:
            raise ValueError(f"Unexpected image type: {type(item)}")
        loaded.append(item)

    original_sizes = []
    all_images = []
    crop_top_lefts = []

    for img in loaded:
        original_sizes.append((img.height, img.width))
        img = train_resize(img)
        # Manual 50% flip (train_flip itself has p=1.0).
        if args.random_flip and random.random() < 0.5:
            img = train_flip(img)
        if args.center_crop:
            top = max(0, int(round((img.height - args.resolution) / 2.0)))
            left = max(0, int(round((img.width - args.resolution) / 2.0)))
            img = train_crop(img)
        else:
            top, left, crop_h, crop_w = train_crop.get_params(img, (args.resolution, args.resolution))
            img = crop(img, top, left, crop_h, crop_w)
        crop_top_lefts.append((top, left))
        all_images.append(train_transforms(img))

    examples["original_sizes"] = original_sizes
    examples["crop_top_lefts"] = crop_top_lefts
    examples["pixel_values"] = all_images

    tokens_one, tokens_two, tokens_three = tokenize_captions(examples)
    examples["input_ids_one"] = tokens_one
    examples["input_ids_two"] = tokens_two
    examples["input_ids_three"] = tokens_three
    return examples
|
|
# main_process_first(): rank 0 performs the (cached) shuffle/select/transform
# setup first so other ranks reuse the cache instead of racing on it.
with accelerator.main_process_first():
    if args.max_train_samples is not None:
        # Subsample for quick runs; guard for dataset types that don't
        # implement shuffle/select.
        if hasattr(train_dataset, 'shuffle') and hasattr(train_dataset, 'select'):
            train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))

    # Apply the preprocessing lazily at item-access time.
    if hasattr(train_dataset, 'with_transform'):
        train_dataset = train_dataset.with_transform(preprocess_train, output_all_columns=True)
def collate_fn(examples):
    """Stack per-example tensors into a single batch dict for the DataLoader."""
    stacked_pixels = torch.stack([ex["pixel_values"] for ex in examples])
    # Force a contiguous float tensor regardless of upstream memory layout.
    stacked_pixels = stacked_pixels.to(memory_format=torch.contiguous_format).float()

    batch = {
        "pixel_values": stacked_pixels,
        "input_ids_one": torch.stack([ex["input_ids_one"] for ex in examples]),
        "input_ids_two": torch.stack([ex["input_ids_two"] for ex in examples]),
        "input_ids_three": torch.stack([ex["input_ids_three"] for ex in examples]),
        # Conditioning metadata stays as plain Python lists (not tensors).
        "original_sizes": [ex["original_sizes"] for ex in examples],
        "crop_top_lefts": [ex["crop_top_lefts"] for ex in examples],
    }
    return batch
|
|
| |
# Heuristic: with multiple GPU processes and no explicit worker count, give
# each process a few DataLoader workers (capped at 4).
# NOTE(review): os.cpu_count() can return None, which would raise here — TODO confirm.
if args.dataloader_num_workers == 0 and accelerator.num_processes > 1:
    args.dataloader_num_workers = min(4, os.cpu_count() // accelerator.num_processes)
    logger.info(f"Auto-setting dataloader_num_workers to {args.dataloader_num_workers}")

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=args.train_batch_size,
    num_workers=args.dataloader_num_workers,
    pin_memory=True,
    # Keep workers alive across epochs (only valid when num_workers > 0).
    persistent_workers=args.dataloader_num_workers > 0,
)
|
|
| |
# Derive the training schedule. `overrode_max_train_steps` remembers whether
# max_train_steps was inferred from the epoch count (so it can be recomputed
# after accelerator.prepare() possibly reshards the dataloader).
overrode_max_train_steps = False

try:
    train_dataloader_len = len(train_dataloader)
except TypeError:
    # Fix: was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit).
    # len() is undefined for iterable-style dataloaders and raises TypeError;
    # fall back to a dummy length of 1 in that case.
    train_dataloader_len = 1
num_update_steps_per_epoch = math.ceil(train_dataloader_len / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

# The scheduler is stepped once per micro-step, hence both warmup and total
# step counts are scaled by gradient_accumulation_steps.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
|
|
| |
# Hand everything to Accelerate: moves to device, wraps the model for DDP and
# reshards the dataloader across processes.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

# Enable activation checkpointing on the rectified-noise module of the
# unwrapped (non-DDP) model.
if args.gradient_checkpointing:
    unwrap_model(model).rectified_noise_module.gradient_checkpointing_enable()

# Recompute the schedule: prepare() may have changed the per-process
# dataloader length.
try:
    train_dataloader_len = len(train_dataloader)
except TypeError:
    # Fix: was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit).
    # len() raises TypeError for iterable-style dataloaders.
    train_dataloader_len = 1
num_update_steps_per_epoch = math.ceil(train_dataloader_len / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Re-derive the epoch count from the (possibly updated) step budget.
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
|
| |
# Initialize experiment trackers on the main process only; disable reporting
# if tracker setup fails so the run can continue.
if accelerator.is_main_process:
    try:
        accelerator.init_trackers("rectified-noise-fine-tune", config=vars(args))
    except Exception as e:
        logger.warning(f"Failed to initialize trackers: {e}")
        args.report_to = None

# Log the effective training configuration.
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running Rectified Noise training *****")

try:
    train_dataset_len = len(train_dataset)
except TypeError:
    # Fix: was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit).
    # Iterable datasets have no len(); report 0 instead of crashing the banner.
    train_dataset_len = 0
logger.info(f" Num examples = {train_dataset_len}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
logger.info(f" Number of GPU processes = {accelerator.num_processes}")
logger.info(f" Number of SIT layers = {args.num_sit_layers}")
logger.info(f" SIT learning rate = {args.sit_learning_rate}")
logger.info(f" KL loss weight = {args.kl_loss_weight}")

# Step/epoch counters (advanced below when resuming from a checkpoint).
global_step = 0
first_epoch = 0
|
|
| |
# Optionally resume from a saved accelerate checkpoint under output_dir.
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)
        # Fix: validate the explicit checkpoint too. Previously only the
        # "latest" branch could yield path=None, so a missing explicit
        # checkpoint crashed inside accelerator.load_state() instead of
        # taking the graceful "does not exist" branch below.
        if not os.path.exists(os.path.join(args.output_dir, path)):
            path = None
    else:
        # "latest": pick the checkpoint-<N> directory with the highest step.
        dirs = os.listdir(args.output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist.")
        args.resume_from_checkpoint = None
        initial_global_step = 0
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        # The step count is encoded in the directory name: checkpoint-<global_step>.
        global_step = int(path.split("-")[1])
        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
else:
    initial_global_step = 0
|
|
# Progress bar over optimizer steps; only the local main process renders it.
# `initial` accounts for steps already completed when resuming.
progress_bar = tqdm(
    range(0, args.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    disable=not accelerator.is_local_main_process,
)
|
|
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    """Look up the flow-matching sigma for each entry of `timesteps`.

    Returns a tensor of shape (batch, 1, 1, ...) right-padded with singleton
    dims so it broadcasts against an `n_dim`-dimensional model input. Falls
    back to sigma=1.0 when the scheduler exposes no sigmas/timesteps.
    """
    try:
        if hasattr(noise_scheduler_copy, 'sigmas'):
            sigmas_attr = getattr(noise_scheduler_copy, 'sigmas', None)
            if sigmas_attr is not None and hasattr(sigmas_attr, 'to'):
                sigmas = sigmas_attr.to(device=accelerator.device, dtype=dtype)
            else:
                sigmas = torch.tensor([1.0], device=accelerator.device, dtype=dtype)
        else:
            sigmas = torch.tensor([1.0], device=accelerator.device, dtype=dtype)

        if hasattr(noise_scheduler_copy, 'timesteps'):
            timesteps_attr = getattr(noise_scheduler_copy, 'timesteps', None)
            if timesteps_attr is not None and hasattr(timesteps_attr, 'to'):
                schedule_timesteps = timesteps_attr.to(accelerator.device)
            else:
                schedule_timesteps = torch.tensor([1.0], device=accelerator.device)
        else:
            schedule_timesteps = torch.tensor([1.0], device=accelerator.device)
    except Exception:
        # Fix: was a bare `except:` (also swallowed KeyboardInterrupt/SystemExit).
        sigmas = torch.tensor([1.0], device=accelerator.device, dtype=dtype)
        schedule_timesteps = torch.tensor([1.0], device=accelerator.device)

    timesteps = timesteps.to(accelerator.device)
    # Map each requested timestep to its index in the schedule.
    # NOTE(review): .item() assumes each timestep occurs exactly once in the
    # schedule (true for FlowMatchEulerDiscreteScheduler) — TODO confirm.
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    # Right-pad singleton dims until sigma is broadcastable to the model input.
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
|
|
| |
# ---- Main training loop ----
for epoch in range(first_epoch, args.num_train_epochs):
    model.train()
    # Running sums for the current logging window (reset after each optimizer step).
    train_loss = 0.0
    reconstruction_loss_total = 0.0
    kl_loss_total = 0.0

    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(model):
            # Encode images to latents with the VAE.
            pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
            try:
                encoded_result = vae.encode(pixel_values)
                if hasattr(encoded_result, 'latent_dist'):
                    model_input = encoded_result.latent_dist.sample()
                else:
                    model_input = encoded_result.sample() if hasattr(encoded_result, 'sample') else encoded_result
            except:
                # NOTE(review): bare except + random fallback means a VAE failure
                # silently trains on pure-noise "latents" — consider failing loudly.
                model_input = torch.randn(pixel_values.shape[0], 4, pixel_values.shape[2]//8, pixel_values.shape[3]//8, device=pixel_values.device, dtype=pixel_values.dtype)

            # Latent normalization constants from the VAE config (with defaults).
            try:
                vae_config_shift_factor = getattr(vae.config, 'shift_factor', 0.0) if hasattr(vae, 'config') else 0.0
                vae_config_scaling_factor = getattr(vae.config, 'scaling_factor', 1.0) if hasattr(vae, 'config') else 1.0
            except:
                vae_config_shift_factor = 0.0
                vae_config_scaling_factor = 1.0

            if not isinstance(model_input, torch.Tensor):
                model_input = torch.tensor(model_input, device=pixel_values.device, dtype=pixel_values.dtype)
            model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
            model_input = model_input.to(dtype=weight_dtype)

            # Text conditioning from the three SD3 encoders (using the
            # pre-tokenized ids produced by the collate_fn).
            prompt_embeds, pooled_prompt_embeds = encode_prompt(
                text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
                prompt=None,
                max_sequence_length=args.max_sequence_length,
                text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"], batch["input_ids_three"]],
            )

            # Sample per-example timesteps according to the configured density.
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]

            u = compute_density_for_timestep_sampling(
                weighting_scheme=args.weighting_scheme,
                batch_size=bsz,
                logit_mean=args.logit_mean,
                logit_std=args.logit_std,
                mode_scale=args.mode_scale,
            )

            try:
                num_train_timesteps = getattr(noise_scheduler_copy.config, 'num_train_timesteps', 1000) if hasattr(noise_scheduler_copy, 'config') else 1000
            except:
                num_train_timesteps = 1000

            indices = (u * num_train_timesteps).long()

            # Map sampled density values to scheduler timesteps; fall back to
            # uniform random timesteps if the scheduler exposes none.
            try:
                if hasattr(noise_scheduler_copy, 'timesteps'):
                    timesteps_attr = getattr(noise_scheduler_copy, 'timesteps', None)
                    if timesteps_attr is not None:
                        timesteps = timesteps_attr[indices].to(device=model_input.device)
                    else:
                        timesteps = torch.randint(0, num_train_timesteps, (bsz,), device=model_input.device).long()
                else:
                    timesteps = torch.randint(0, num_train_timesteps, (bsz,), device=model_input.device).long()
            except:
                timesteps = torch.randint(0, num_train_timesteps, (bsz,), device=model_input.device).long()

            # Rectified-flow interpolation: x_t = (1 - sigma) * x_0 + sigma * noise.
            sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
            noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

            model_output = model(
                hidden_states=noisy_model_input,
                timestep=timesteps,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=True,
            )

            model_pred = model_output["sample"]
            mean_output = model_output["mean_output"]
            var_output = model_output["var_output"]

            # Preconditioning: convert the prediction to x_0 space and regress
            # against the clean latents; otherwise regress the velocity.
            if args.precondition_outputs:
                model_pred = model_pred * (-sigmas) + noisy_model_input
                target = model_input
            else:
                target = noise - model_input

            loss, reconstruction_loss, kl_loss = compute_loss(
                model_pred, target, mean_output, var_output, args, sigmas, timesteps
            )

            # Gather each loss term across processes for logging (best-effort:
            # on failure, fall back to this process's local value).
            try:
                gathered_loss = accelerator.gather(loss.repeat(args.train_batch_size))
                avg_loss = torch.mean(gathered_loss) if isinstance(gathered_loss, torch.Tensor) else gathered_loss.mean()
            except:
                avg_loss = loss

            try:
                gathered_recon_loss = accelerator.gather(reconstruction_loss.repeat(args.train_batch_size))
                avg_recon_loss = torch.mean(gathered_recon_loss) if isinstance(gathered_recon_loss, torch.Tensor) else gathered_recon_loss.mean()
            except:
                avg_recon_loss = reconstruction_loss

            try:
                gathered_kl_loss = accelerator.gather(kl_loss.repeat(args.train_batch_size))
                avg_kl_loss = torch.mean(gathered_kl_loss) if isinstance(gathered_kl_loss, torch.Tensor) else gathered_kl_loss.mean()
            except:
                avg_kl_loss = kl_loss
            train_loss += avg_loss.item() / args.gradient_accumulation_steps
            reconstruction_loss_total += avg_recon_loss.item() / args.gradient_accumulation_steps
            kl_loss_total += avg_kl_loss.item() / args.gradient_accumulation_steps

            # Backprop; clip only on the micro-step that syncs gradients.
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                if args.max_grad_norm is not None and args.max_grad_norm > 0:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # A full optimizer step just completed (all accumulation micro-steps done).
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            if hasattr(accelerator, 'trackers') and accelerator.trackers:
                accelerator.log({
                    "train_loss": train_loss,
                    "reconstruction_loss": reconstruction_loss_total,
                    "kl_loss": kl_loss_total,
                    "learning_rate": lr_scheduler.get_last_lr()[0]
                }, step=global_step)
            train_loss = 0.0
            reconstruction_loss_total = 0.0
            kl_loss_total = 0.0

            # Periodic validation image generation (main process only).
            if accelerator.is_main_process and args.validation_prompt and args.num_validation_images > 0 and args.validation_steps > 0:
                if global_step % args.validation_steps == 0:
                    logger.info(f"Running validation... \n Generating {args.num_validation_images} images with prompt: {args.validation_prompt}")
                    # Build a full SD3 pipeline around the current (unwrapped) transformer.
                    validation_pipeline = StableDiffusion3Pipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        vae=vae,
                        text_encoder=text_encoder_one,
                        text_encoder_2=text_encoder_two,
                        text_encoder_3=text_encoder_three,
                        transformer=unwrap_model(model).transformer,
                        revision=args.revision,
                        variant=args.variant,
                        torch_dtype=weight_dtype,
                    )
                    # Attach the wrapped training model so log_validation can use it.
                    validation_pipeline.model=model
                    images = log_validation(validation_pipeline, args, accelerator, epoch, global_step=global_step)
                    del validation_pipeline

            # Periodic checkpointing (main process only).
            if accelerator.is_main_process:
                if global_step % args.checkpointing_steps == 0:
                    # Enforce the retention limit before writing a new checkpoint.
                    if args.checkpoints_total_limit is not None:
                        removed_count, remaining_count = manage_checkpoints(
                            args.output_dir, args.checkpoints_total_limit
                        )
                        if removed_count > 0:
                            logger.info(f"清理检查点: 删除了 {removed_count} 个旧检查点")

                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    try:
                        accelerator.save_state(save_path)
                        logger.info(f"保存检查点到: {save_path}")

                        # Optionally also dump just the SIT weights for lightweight reuse.
                        if args.save_sit_weights_only:
                            sit_save_path = os.path.join(save_path, "sit_weights")
                            save_sit_weights(unwrap_model(model), sit_save_path)
                            logger.info(f"保存 SIT 权重到: {sit_save_path}")

                    except Exception as e:
                        logger.error(f"保存检查点失败: {e}")
                        # Remove the partially written checkpoint directory.
                        if os.path.exists(save_path):
                            try:
                                shutil.rmtree(save_path)
                                logger.info(f"清理失败的检查点目录: {save_path}")
                            except:
                                pass
        # Per-micro-step postfix on the progress bar.
        logs = {
            "step_loss": loss.detach().item(),
            "recon_loss": reconstruction_loss.detach().item(),
            "kl_loss": kl_loss.detach().item(),
            "lr": lr_scheduler.get_last_lr()[0]
        }
        progress_bar.set_postfix(**logs)

        # Stop mid-epoch once the step budget is exhausted.
        if global_step >= args.max_train_steps:
            break
|
|
| |
# ---- Final save ----
# Sync all ranks before the main process writes the final artifacts.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    final_model = unwrap_model(model)

    # Persist the SIT weights one last time alongside the checkpoints.
    final_sit_weights_path = os.path.join(args.output_dir, "sit_weights")
    save_sit_weights(final_model, final_sit_weights_path)

    logger.info("="*60)
    logger.info("训练已完成!")
    logger.info(f"SIT 权重已保存到: {final_sit_weights_path}")
    logger.info(f"检查点保存在: {args.output_dir}/checkpoint-*")
    logger.info("="*60)

# Flush trackers and tear down the distributed state.
accelerator.end_training()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then run the training routine.
    cli_args = parse_args()
    main(cli_args)