| |
| |
| """ |
| 分布式采样脚本:支持指定 LoRA 权重与 Rectified Noise(SIT) 权重 |
| |
| 依据 train_rectified_noise.py 的模型结构,加载并组装 SD3WithRectifiedNoise 进行采样。 |
| """ |
|
|
| import os |
| import sys |
| import json |
| import math |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed as dist |
| from tqdm import tqdm |
| import numpy as np |
| from PIL import Image |
|
|
| from accelerate import Accelerator |
| from diffusers import StableDiffusion3Pipeline |
| from peft import LoraConfig, get_peft_model_state_dict |
| from peft.utils import set_peft_model_state_dict |
|
|
|
|
def dynamic_import_training_classes(project_root: str):
    """Dynamically import RectifiedNoiseModule and SD3WithRectifiedNoise.

    The classes are defined in ``train_rectified_noise.py``; importing them
    at call time avoids a hard module-level dependency on the training script.

    Args:
        project_root: Directory containing ``train_rectified_noise.py``.

    Returns:
        Tuple ``(RectifiedNoiseModule, SD3WithRectifiedNoise)``.

    Raises:
        ImportError: If the module or either class cannot be imported; the
            original exception is chained as the cause.
    """
    # Guard against growing sys.path with duplicates on repeated calls.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    try:
        import train_rectified_noise as trn
        return trn.RectifiedNoiseModule, trn.SD3WithRectifiedNoise
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ImportError(f"无法从 train_rectified_noise.py 导入类: {e}") from e
|
|
def create_npz_from_sample_folder(sample_dir, num_samples):
    """
    Build a single ``.npz`` file from a folder of PNG samples, keeping the
    same format as sample_ddp_new (one uint8 array stored under ``arr_0``).

    Files are read in sorted-name order. If the folder holds fewer than
    ``num_samples`` images, the remainder is padded with zero images so the
    output array always has ``num_samples`` entries. (The original code had
    this padding branch, but it was dead: the loop range was capped at
    ``min(num_samples, len(actual_files))`` so it could never execute.)

    Args:
        sample_dir: Directory containing ``*.png`` sample images.
        num_samples: Number of entries the output array should contain.

    Returns:
        Path of the written ``.npz`` file, or None if nothing was written
        (``num_samples`` == 0).
    """
    actual_files = sorted(
        name for name in os.listdir(sample_dir) if name.endswith('.png')
    )

    samples = []
    for i in tqdm(range(num_samples), desc="Building .npz file from samples"):
        if i < len(actual_files):
            sample_path = os.path.join(sample_dir, actual_files[i])
            sample_pil = Image.open(sample_path)
            sample_np = np.asarray(sample_pil).astype(np.uint8)
            samples.append(sample_np)
        else:
            # Pad missing samples with zero images. Reuse the shape of the
            # last real sample when available so np.stack stays valid;
            # fall back to 512x512 RGB otherwise.
            pad_shape = samples[-1].shape if samples else (512, 512, 3)
            samples.append(np.zeros(pad_shape, dtype=np.uint8))

    if samples:
        samples = np.stack(samples)
        npz_path = f"{sample_dir}.npz"
        np.savez(npz_path, arr_0=samples)
        print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
        return npz_path
    else:
        print("No samples found to create npz file.")
        return None
|
|
|
|
def get_existing_sample_count(sample_dir):
    """Return ``(count, max_index)`` for existing ``<index>.png`` samples.

    Only files whose stem is purely decimal digits are considered (e.g.
    ``0.png``, ``123.png``). Returns ``(0, -1)`` when the directory does not
    exist or contains no such files. Prints a warning when the indices are
    non-contiguous, i.e. some files in 0..max_index are missing.

    Args:
        sample_dir: Directory to scan.

    Returns:
        Tuple of (number of matching files, highest index found).
    """
    if not os.path.exists(sample_dir):
        return 0, -1

    # isdecimal() (rather than isdigit()) accepts exactly the characters
    # int() can parse, so no try/except around the conversion is needed.
    indices = [
        int(name[:-4])
        for name in os.listdir(sample_dir)
        if name.endswith('.png') and name[:-4].isdecimal()
    ]

    if not indices:
        return 0, -1

    max_index = max(indices)
    count = len(indices)

    # Contiguous indices 0..max_index would imply count == max_index + 1.
    expected_count = max_index + 1
    if count < expected_count:
        print(f"Warning: Found {count} files but expected {expected_count} (missing some indices)")

    return count, max_index
|
|
|
|
|
|
def load_sit_weights(rectified_module, weights_path: str, rank=0):
    """Load Rectified Noise (SIT) weights into *rectified_module*.

    Supports ``.safetensors`` / ``.bin`` / ``.pt`` files. When *weights_path*
    is a directory, both the directory itself and its ``sit_weights``
    subdirectory are probed, i.e. these layouts work:
      - weights_path/pytorch_sit_weights.safetensors
      - weights_path/sit_weights/pytorch_sit_weights.safetensors

    Returns True when a weight file was loaded, False otherwise. Only rank 0
    prints progress.
    """
    if not os.path.isdir(weights_path):
        # A single weight file was given directly.
        try:
            if rank == 0:
                print(f"Loading rectified weights from file: {weights_path}")
            if weights_path.endswith(".safetensors"):
                from safetensors.torch import load_file
                state = load_file(weights_path)
            else:
                state = torch.load(weights_path, map_location="cpu")
            missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
            if rank == 0:
                print(f" Loaded rectified weights: {len(state)} keys")
                if missing_keys:
                    print(f" Missing keys: {len(missing_keys)}")
                if unexpected_keys:
                    print(f" Unexpected keys: {len(unexpected_keys)}")
            return True
        except Exception as e:
            if rank == 0:
                print(f" ❌ Failed to load rectified weights from {weights_path}: {e}")
            return False

    # Directory: probe candidate locations in priority order.
    for probe_dir in (weights_path, os.path.join(weights_path, "sit_weights")):
        if not os.path.exists(probe_dir):
            continue

        # 1) Canonical safetensors filename.
        canonical = os.path.join(probe_dir, "pytorch_sit_weights.safetensors")
        if os.path.exists(canonical):
            try:
                from safetensors.torch import load_file
                if rank == 0:
                    print(f"Loading rectified weights from: {canonical}")
                state = load_file(canonical)
                missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
                if rank == 0:
                    print(f" Loaded rectified weights: {len(state)} keys")
                    if missing_keys:
                        print(f" Missing keys: {len(missing_keys)}")
                    if unexpected_keys:
                        print(f" Unexpected keys: {len(unexpected_keys)}")
                return True
            except Exception as e:
                if rank == 0:
                    print(f" Failed to load from {canonical}: {e}")
                # Mirror the original flow: on safetensors failure, move on
                # to the next search directory.
                continue

        # 2) Well-known torch filenames.
        for known_name in ("pytorch_sit_weights.bin", "pytorch_sit_weights.pt", "sit_weights.pt", "sit.pt"):
            candidate = os.path.join(probe_dir, known_name)
            if not os.path.exists(candidate):
                continue
            try:
                if rank == 0:
                    print(f"Loading rectified weights from: {candidate}")
                state = torch.load(candidate, map_location="cpu")
                missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
                if rank == 0:
                    print(f" Loaded rectified weights: {len(state)} keys")
                    if missing_keys:
                        print(f" Missing keys: {len(missing_keys)}")
                    if unexpected_keys:
                        print(f" Unexpected keys: {len(unexpected_keys)}")
                return True
            except Exception as e:
                if rank == 0:
                    print(f" Failed to load from {candidate}: {e}")

        # 3) Last resort: any *.pt / *.bin file found in the directory.
        try:
            for entry in os.listdir(probe_dir):
                if not entry.endswith((".pt", ".bin")):
                    continue
                candidate = os.path.join(probe_dir, entry)
                try:
                    if rank == 0:
                        print(f"Loading rectified weights from: {candidate}")
                    state = torch.load(candidate, map_location="cpu")
                    missing_keys, unexpected_keys = rectified_module.load_state_dict(state, strict=False)
                    if rank == 0:
                        print(f" Loaded rectified weights: {len(state)} keys")
                    return True
                except Exception as e:
                    if rank == 0:
                        print(f" Failed to load from {candidate}: {e}")
        except Exception:
            # Listing failed; fall through to the not-found message below.
            pass

    if rank == 0:
        print(f" ❌ No rectified weights found in {weights_path} or {os.path.join(weights_path, 'sit_weights')}")
    return False
|
|
|
|
def check_lora_weights_exist(lora_path):
    """Return True if *lora_path* points at LoRA weights.

    A directory qualifies when it contains the canonical
    ``pytorch_lora_weights.safetensors`` or any ``*.safetensors`` file whose
    name mentions "lora"; a plain file qualifies when it ends in
    ``.safetensors``. Empty/None paths and anything else return False.
    """
    if not lora_path:
        return False

    if os.path.isdir(lora_path):
        # Canonical diffusers filename first, then any lora-named safetensors.
        canonical = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
        if os.path.exists(canonical):
            return True
        return any(
            entry.endswith(".safetensors") and "lora" in entry.lower()
            for entry in os.listdir(lora_path)
        )

    if os.path.isfile(lora_path):
        return lora_path.endswith(".safetensors")

    return False
|
|
|
|
def load_lora_from_checkpoint(pipeline, checkpoint_path, rank: int = 0, lora_rank: int = 64):
    """
    Load LoRA weights or full (merged) model weights from an accelerator
    checkpoint directory into ``pipeline.transformer`` (modified in place).

    Three checkpoint layouts are handled, detected from the keys found in
    ``checkpoint_path/model.safetensors``:
      1. PEFT format (``base_layer`` weights + ``lora_A``/``lora_B`` factors):
         the LoRA delta ``B @ A`` is merged into the base weights and the
         result is loaded as plain transformer weights.
      2. Merged/full weights (non-LoRA transformer keys present): transformer
         weights are extracted and loaded directly.
      3. LoRA-only weights: a fresh LoRA adapter is added to the transformer
         and the state dict is loaded into it via PEFT.

    Args:
        pipeline: StableDiffusion3Pipeline whose ``.transformer`` receives the
            weights.
        checkpoint_path: Accelerator checkpoint directory containing
            ``model.safetensors``.
        rank: Distributed rank; only rank 0 prints progress.
        lora_rank: Fallback LoRA rank used when the rank cannot be detected
            from the checkpoint's ``lora_A`` shapes.

    Returns:
        True on success, False on any failure (never raises; all errors are
        caught and reported).
    """
    if rank == 0:
        print(f"Loading weights from accelerator checkpoint: {checkpoint_path}")

    try:
        from safetensors.torch import load_file
        model_file = os.path.join(checkpoint_path, "model.safetensors")
        if not os.path.exists(model_file):
            if rank == 0:
                print(f"Model file not found: {model_file}")
            return False

        state_dict = load_file(model_file)
        all_keys = list(state_dict.keys())

        # Classify keys to decide which of the three layouts this checkpoint is.
        lora_keys = [k for k in all_keys if 'lora' in k.lower() and 'transformer' in k.lower()]
        base_layer_keys = [k for k in all_keys if 'base_layer' in k.lower() and 'transformer' in k.lower()]
        non_lora_transformer_keys = [k for k in all_keys if 'lora' not in k.lower() and 'base_layer' not in k.lower() and 'transformer' in k.lower()]

        if rank == 0:
            print(f"Checkpoint analysis:")
            print(f" Total keys: {len(all_keys)}")
            print(f" LoRA keys: {len(lora_keys)}")
            print(f" Base layer keys: {len(base_layer_keys)}")
            print(f" Direct transformer weight keys (merged): {len(non_lora_transformer_keys)}")

        # ----- Layout 1: PEFT format (base_layer + LoRA factors) -> merge -----
        if len(base_layer_keys) > 0:
            if rank == 0:
                print(f"✓ Detected PEFT format (base_layer + LoRA), merging weights...")

            # Final plain state dict to load into the transformer.
            merged_state_dict = {}
            # module_key -> {base_weight, base_bias, lora_A, lora_B}, each a
            # (original_key, tensor) pair (or None when absent).
            modules_to_merge = {}
            non_lora_keys_found = []

            for key in all_keys:
                # Strip known wrapper prefixes so keys match the bare
                # transformer's state-dict naming.
                new_key = key
                has_transformer_prefix = False

                if key.startswith('base_model.model.transformer.'):
                    new_key = key[len('base_model.model.transformer.'):]
                    has_transformer_prefix = True
                elif key.startswith('model.transformer.'):
                    new_key = key[len('model.transformer.'):]
                    has_transformer_prefix = True
                elif key.startswith('transformer.'):
                    new_key = key[len('transformer.'):]
                    has_transformer_prefix = True
                elif 'transformer' in key.lower():
                    # No recognized prefix, but still transformer-related:
                    # keep the key unchanged.
                    has_transformer_prefix = True

                if not has_transformer_prefix:
                    continue

                # Route each tensor into the per-module merge record.
                if '.base_layer.weight' in new_key:
                    # e.g. "...attn.to_q.base_layer.weight" -> "...attn.to_q.weight"
                    module_key = new_key.replace('.base_layer.weight', '.weight')
                    if module_key not in modules_to_merge:
                        modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
                    modules_to_merge[module_key]['base_weight'] = (key, state_dict[key])
                elif '.base_layer.bias' in new_key:
                    module_key = new_key.replace('.base_layer.bias', '.bias')
                    if module_key not in modules_to_merge:
                        modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
                    modules_to_merge[module_key]['base_bias'] = (key, state_dict[key])
                elif '.lora_A.default.weight' in new_key:
                    module_key = new_key.replace('.lora_A.default.weight', '.weight')
                    if module_key not in modules_to_merge:
                        modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
                    modules_to_merge[module_key]['lora_A'] = (key, state_dict[key])
                elif '.lora_B.default.weight' in new_key:
                    module_key = new_key.replace('.lora_B.default.weight', '.weight')
                    if module_key not in modules_to_merge:
                        modules_to_merge[module_key] = {'base_weight': None, 'base_bias': None, 'lora_A': None, 'lora_B': None}
                    modules_to_merge[module_key]['lora_B'] = (key, state_dict[key])
                elif 'lora' not in new_key.lower() and 'base_layer' not in new_key.lower():
                    # Regular (non-adapted) transformer weight: pass through as-is.
                    merged_state_dict[new_key] = state_dict[key]
                    non_lora_keys_found.append(new_key)

            if rank == 0:
                print(f" Found {len(non_lora_keys_found)} non-LoRA transformer keys in checkpoint")
                if non_lora_keys_found:
                    print(f" Sample non-LoRA keys: {non_lora_keys_found[:10]}")

            if rank == 0:
                print(f" Merging {len(modules_to_merge)} modules...")

            # NOTE(review): redundant with the module-level torch import, but
            # kept as-is.
            import torch
            for module_key, weights in modules_to_merge.items():
                if weights['base_weight'] is not None:
                    base_key, base_weight = weights['base_weight']
                    base_weight = base_weight.clone()

                    if weights['lora_A'] is not None and weights['lora_B'] is not None:
                        lora_A_key, lora_A = weights['lora_A']
                        lora_B_key, lora_B = weights['lora_B']

                        # lora_A has shape (r, in_features); alpha is assumed
                        # equal to r so the scaling factor alpha/r is 1.0 —
                        # matches a LoraConfig(r=X, lora_alpha=X) training setup.
                        rank_value = lora_A.shape[0]
                        alpha = rank_value

                        # W_merged = W_base + (alpha/r) * B @ A
                        lora_delta = torch.matmul(lora_B, lora_A)

                        if lora_delta.shape == base_weight.shape:
                            merged_weight = base_weight + lora_delta * (alpha / rank_value)
                            merged_state_dict[module_key] = merged_weight
                            if rank == 0 and len(modules_to_merge) <= 20:
                                print(f" ✓ Merged {module_key}: {base_weight.shape}")
                        else:
                            if rank == 0:
                                print(f" ⚠️ Shape mismatch for {module_key}: base={base_weight.shape}, lora_delta={lora_delta.shape}, using base only")
                            merged_state_dict[module_key] = base_weight
                    else:
                        # No LoRA factors for this module: keep the base weight.
                        merged_state_dict[module_key] = base_weight

                # Biases are never LoRA-adapted; copy them through directly.
                if '.bias' in module_key and weights['base_bias'] is not None:
                    bias_key, base_bias = weights['base_bias']
                    merged_state_dict[module_key] = base_bias.clone()

            if rank == 0:
                print(f" Merged {len(merged_state_dict)} weights")
                print(f" Sample merged keys: {list(merged_state_dict.keys())[:5]}")

            try:
                # strict=False: any key not present in the checkpoint keeps its
                # pretrained value.
                missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(merged_state_dict, strict=False)

                if rank == 0:
                    print(f" Loaded merged weights:")
                    print(f" Missing keys: {len(missing_keys)}")
                    print(f" Unexpected keys: {len(unexpected_keys)}")
                    if missing_keys:
                        print(f" Missing keys: {missing_keys}")

                        # Warn loudly when structurally important modules are
                        # absent from the checkpoint.
                        critical_keys = ['pos_embed', 'time_text_embed', 'context_embedder', 'norm_out', 'proj_out']
                        has_critical = any(any(ck in mk for ck in critical_keys) for mk in missing_keys)
                        if has_critical:
                            print(f" ⚠️ WARNING: Missing critical keys! These should be loaded from pretrained model.")
                            print(f" The missing keys will use values from the pretrained model (not fine-tuned).")

                if len(missing_keys) > 0:
                    if rank == 0:
                        print(f" Note: Missing keys will use pretrained model weights (not fine-tuned)")

                if rank == 0:
                    print(f" ✓ Successfully loaded merged model weights")
                return True

            except Exception as e:
                if rank == 0:
                    print(f" ❌ Error loading merged weights: {e}")
                    import traceback
                    traceback.print_exc()
                return False

        # ----- Layout 2: merged/full transformer weights -> load directly -----
        elif len(non_lora_transformer_keys) > 0:
            if rank == 0:
                print(f"✓ Detected merged model weights (contains full transformer weights)")
                print(f" Loading full model weights directly...")

            transformer_state_dict = {}
            for key, value in state_dict.items():
                # Strip the same wrapper prefixes as in the PEFT branch.
                new_key = key
                if key.startswith('base_model.model.transformer.'):
                    new_key = key[len('base_model.model.transformer.'):]
                elif key.startswith('model.transformer.'):
                    new_key = key[len('model.transformer.'):]
                elif key.startswith('transformer.'):
                    new_key = key[len('transformer.'):]

                # Keep only keys belonging to the transformer's known
                # submodules (plus any lora keys that slipped through).
                if (new_key.startswith('transformer_blocks') or
                        new_key.startswith('pos_embed') or
                        new_key.startswith('time_text_embed') or
                        'lora' in new_key.lower()):
                    transformer_state_dict[new_key] = value

            if rank == 0:
                print(f" Extracted {len(transformer_state_dict)} transformer weight keys")
                print(f" Sample keys: {list(transformer_state_dict.keys())[:5]}")

            try:
                missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(transformer_state_dict, strict=False)

                if rank == 0:
                    print(f" Loaded full model weights:")
                    print(f" Missing keys: {len(missing_keys)}")
                    print(f" Unexpected keys: {len(unexpected_keys)}")
                    if missing_keys:
                        print(f" Sample missing keys: {missing_keys[:5]}")
                    if unexpected_keys:
                        print(f" Sample unexpected keys: {unexpected_keys[:5]}")

                # Heuristic sanity check: more than half the target keys
                # missing means the load effectively failed.
                if len(missing_keys) > len(transformer_state_dict) * 0.5:
                    if rank == 0:
                        print(f" ⚠️ WARNING: Too many missing keys, weights may not be fully loaded")
                    return False

                if rank == 0:
                    print(f" ✓ Successfully loaded merged model weights")
                return True

            except Exception as e:
                if rank == 0:
                    print(f" ❌ Error loading full model weights: {e}")
                    import traceback
                    traceback.print_exc()
                return False

        # ----- Layout 3: LoRA-only weights -> load as a PEFT adapter -----
        if rank == 0:
            print(f"Detected LoRA-only weights, loading as LoRA adapter...")

        # Detect the LoRA rank from a lora_A tensor: shape is (r, in_features).
        detected_rank = None
        for key, value in state_dict.items():
            if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2:
                detected_rank = value.shape[0]
                if rank == 0:
                    print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})")
                break

        # Detected rank wins over the caller-supplied one.
        actual_rank = detected_rank if detected_rank is not None else lora_rank
        if detected_rank is not None and detected_rank != lora_rank:
            if rank == 0:
                print(f"⚠️ Warning: Detected rank ({detected_rank}) differs from requested rank ({lora_rank}), using detected rank")

        # Remove any previously attached "default" adapter so add_adapter
        # below starts clean.
        if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config:
            if "default" in pipeline.transformer.peft_config:
                if rank == 0:
                    print("Removing existing 'default' adapter before adding new one...")
                try:
                    pipeline.unload_lora_weights()
                    if rank == 0:
                        print("Successfully unloaded existing LoRA adapter")
                except Exception as e:
                    if rank == 0:
                        print(f"❌ ERROR: Could not unload existing adapter: {e}")
                        print("Cannot proceed without cleaning up adapter")
                    return False

        # Target modules match the attention projections adapted at training
        # time; lora_alpha == r gives a scaling factor of 1.0.
        transformer_lora_config = LoraConfig(
            r=actual_rank,
            lora_alpha=actual_rank,
            init_lora_weights="gaussian",
            target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
        )

        pipeline.transformer.add_adapter(transformer_lora_config)

        if rank == 0:
            print(f"LoRA adapter configured with rank={actual_rank}")

        # Collect only the LoRA tensors, stripping wrapper prefixes so keys
        # line up with the freshly added adapter.
        lora_state_dict = {}
        for key, value in state_dict.items():
            if 'lora' in key.lower() and 'transformer' in key.lower():
                new_key = key

                if key.startswith('base_model.model.transformer.'):
                    new_key = key[len('base_model.model.transformer.'):]
                elif key.startswith('model.transformer.'):
                    new_key = key[len('model.transformer.'):]
                elif key.startswith('transformer.'):
                    # NOTE(review): both branches strip the same prefix; the
                    # split looks like a leftover from an earlier variant.
                    if not key[len('transformer.'):].startswith('transformer_blocks'):
                        new_key = key[len('transformer.'):]
                    else:
                        new_key = key[len('transformer.'):]

                if 'transformer_blocks' in new_key or 'transformer' in new_key:
                    lora_state_dict[new_key] = value

        if not lora_state_dict:
            if rank == 0:
                print("No LoRA weights found in checkpoint")
                # Dump diagnostics to help identify the checkpoint layout.
                all_keys = list(state_dict.keys())
                print(f"Total keys: {len(all_keys)}")
                print(f"First 20 keys: {all_keys[:20]}")
                lora_related = [k for k in all_keys if 'lora' in k.lower()]
                if lora_related:
                    print(f"Keys containing 'lora': {lora_related[:10]}")
            return False

        if rank == 0:
            print(f"Found {len(lora_state_dict)} LoRA weight keys")
            sample_keys = list(lora_state_dict.keys())[:5]
            print(f"Sample LoRA keys: {sample_keys}")

        try:
            sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""

            if rank == 0:
                print(f"Original key format: {sample_key}")

            # Checkpoints saved through PEFT carry a ".default" adapter-name
            # segment in each key; detect it from a sample key.
            sample_key = list(lora_state_dict.keys())[0] if lora_state_dict else ""
            has_default_suffix = '.default.weight' in sample_key or '.default.bias' in sample_key

            if rank == 0:
                print(f"Sample key: {sample_key}")
                print(f"Has .default suffix: {has_default_suffix}")

            # set_peft_model_state_dict expects keys WITHOUT the adapter-name
            # segment (it re-inserts the adapter name itself), so strip it.
            converted_dict = {}
            for key, value in lora_state_dict.items():
                new_key = key
                if '.default.weight' in new_key:
                    new_key = new_key.replace('.default.weight', '.weight')
                elif '.default.bias' in new_key:
                    new_key = new_key.replace('.default.bias', '.bias')
                elif '.default' in new_key and (new_key.endswith('.weight') or new_key.endswith('.bias')):
                    new_key = new_key.replace('.default', '')

                converted_dict[new_key] = value

            if rank == 0:
                print(f"Converted {len(converted_dict)} keys (removed .default suffix if present)")
                print(f"Sample converted keys: {list(converted_dict.keys())[:5]}")

            incompatible_keys = set_peft_model_state_dict(
                pipeline.transformer,
                converted_dict,
                adapter_name="default"
            )

            if incompatible_keys is not None:
                # May be a NamedTuple or a plain object; read attributes defensively.
                missing_keys = getattr(incompatible_keys, "missing_keys", [])
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", [])

                if rank == 0:
                    print(f"LoRA loading result:")
                    print(f" Missing keys: {len(missing_keys)}")
                    print(f" Unexpected keys: {len(unexpected_keys)}")

                    if len(missing_keys) > 100:
                        print(f" ⚠️ WARNING: Too many missing keys ({len(missing_keys)}), LoRA may not be fully loaded!")
                        print(f" Sample missing keys: {missing_keys[:10]}")
                    elif missing_keys:
                        print(f" Sample missing keys: {missing_keys[:10]}")

                    if unexpected_keys:
                        print(f" Unexpected keys: {unexpected_keys[:10]}")

                # More than half the keys missing -> treat as a failed load.
                if len(missing_keys) > len(converted_dict) * 0.5:
                    if rank == 0:
                        print("❌ ERROR: Too many missing keys, LoRA weights not loaded correctly!")
                    return False
            else:
                if rank == 0:
                    print("✓ LoRA weights loaded (no incompatible keys reported)")

        except RuntimeError as e:
            # Most common RuntimeError here: tensor shape mismatch because the
            # checkpoint was trained with a different LoRA rank.
            error_str = str(e)
            if "size mismatch" in error_str:
                if rank == 0:
                    print(f"❌ Size mismatch error: The checkpoint rank doesn't match the adapter rank")
                    print(f" This usually means the checkpoint was trained with a different rank")

                    # Try to recover the checkpoint's rank from the error text.
                    import re
                    match = re.search(r'copying a param with shape torch\.Size\(\[(\d+),', error_str)
                    if match:
                        checkpoint_rank = int(match.group(1))
                        if rank == 0:
                            print(f" Detected checkpoint rank: {checkpoint_rank}")
                            print(f" Adapter was configured with rank: {actual_rank}")
                            if checkpoint_rank != actual_rank:
                                print(f" ⚠️ Mismatch! Need to recreate adapter with rank={checkpoint_rank}")
            else:
                if rank == 0:
                    print(f"❌ Error setting LoRA state dict: {e}")
                    import traceback
                    traceback.print_exc()

            # Best-effort cleanup of the half-configured adapter.
            try:
                pipeline.unload_lora_weights()
            except:
                pass
            return False
        except Exception as e:
            if rank == 0:
                print(f"❌ Error setting LoRA state dict: {e}")
                import traceback
                traceback.print_exc()

            # Best-effort cleanup of the half-configured adapter.
            try:
                pipeline.unload_lora_weights()
            except:
                pass
            return False

        # Activate the adapter we just populated.
        pipeline.transformer.set_adapter("default")

        if hasattr(pipeline.transformer, 'peft_config'):
            adapters = list(pipeline.transformer.peft_config.keys())
            if rank == 0:
                print(f"LoRA adapters configured: {adapters}")

        if hasattr(pipeline.transformer, 'active_adapters'):
            # active_adapters is a method in some peft/diffusers versions and
            # a property in others; handle both.
            try:
                if callable(pipeline.transformer.active_adapters):
                    active = pipeline.transformer.active_adapters()
                else:
                    active = pipeline.transformer.active_adapters
                if rank == 0:
                    print(f"Active adapters: {active}")
            except:
                if rank == 0:
                    print("Could not get active adapters, but LoRA is configured")

        # Verification pass: scan the injected lora_A/lora_B layers and make
        # sure at least some carry non-zero weights.
        lora_layers_found = 0
        nonzero_lora_layers = 0
        total_lora_weight_sum = 0.0

        for name, module in pipeline.transformer.named_modules():
            if 'lora_A' in name or 'lora_B' in name:
                lora_layers_found += 1
                if hasattr(module, 'weight') and module.weight is not None:
                    weight_sum = module.weight.abs().sum().item()
                    total_lora_weight_sum += weight_sum
                    if weight_sum > 1e-6:
                        nonzero_lora_layers += 1
                        if rank == 0 and nonzero_lora_layers <= 3:
                            print(f"✓ Found non-zero LoRA weight in: {name}, sum={weight_sum:.6f}")

        if rank == 0:
            print(f"LoRA verification:")
            print(f" Total LoRA layers found: {lora_layers_found}")
            print(f" Non-zero LoRA layers: {nonzero_lora_layers}")
            print(f" Total LoRA weight sum: {total_lora_weight_sum:.6f}")

            # NOTE(review): these early returns only fire on rank 0; other
            # ranks rely on the nonzero_lora_layers check below.
            if lora_layers_found == 0:
                print("❌ ERROR: No LoRA layers found in transformer!")
                return False
            elif nonzero_lora_layers == 0:
                print("❌ ERROR: All LoRA weights are zero, LoRA not loaded correctly!")
                return False
            elif nonzero_lora_layers < lora_layers_found * 0.5:
                print(f"⚠️ WARNING: Only {nonzero_lora_layers}/{lora_layers_found} LoRA layers have non-zero weights!")
                print("⚠️ LoRA may not be fully applied!")
            else:
                print(f"✓ LoRA weights verified: {nonzero_lora_layers}/{lora_layers_found} layers have non-zero weights")

        # All-rank guard: fail when no LoRA layer carries weights.
        if nonzero_lora_layers == 0:
            return False

        if rank == 0:
            print("✓ Successfully loaded and verified LoRA weights from checkpoint")

        return True

    except Exception as e:
        if rank == 0:
            print(f"Error loading LoRA from checkpoint: {e}")
            import traceback
            traceback.print_exc()
        return False
|
|
|
|
def load_captions_from_jsonl(jsonl_path):
    """Read captions from a JSONL file.

    Each non-empty line is parsed as JSON; the first of the fields
    'caption', 'text', 'prompt', 'description' holding a string value is
    used (stripped). Lines that are not valid JSON, are not objects, or
    yield an empty caption are skipped.

    Args:
        jsonl_path: Path to the JSONL file.

    Returns:
        List of caption strings; falls back to a single generic caption
        when nothing usable is found.
    """
    captions = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines instead of aborting the whole file;
                # narrowed from a bare `except Exception` so real bugs surface.
                continue
            if not isinstance(data, dict):
                continue
            # First string-valued field wins, even if it strips to empty
            # (matching the original break-on-first-match behavior).
            cap = None
            for field in ('caption', 'text', 'prompt', 'description'):
                value = data.get(field)
                if isinstance(value, str):
                    cap = value.strip()
                    break
            if cap:
                captions.append(cap)
    return captions if captions else ["a beautiful high quality image"]
|
|
|
|
| def main(args): |
| assert torch.cuda.is_available(), "需要GPU运行" |
| dist.init_process_group("nccl") |
| rank = dist.get_rank() |
| world_size = dist.get_world_size() |
| device = rank % torch.cuda.device_count() |
| torch.cuda.set_device(device) |
| seed = args.global_seed * world_size + rank |
| torch.manual_seed(seed) |
|
|
| print(f"[rank{rank}] DDP initialized, device={device}, seed={seed}, world_size={world_size}") |
|
|
| |
| if rank == 0: |
| print("=" * 80) |
| print("参数检查:") |
| print(f" lora_path: {args.lora_path}") |
| print(f" rectified_weights: {args.rectified_weights}") |
| print(f" lora_path is None: {args.lora_path is None}") |
| print(f" lora_path is empty: {args.lora_path == '' if args.lora_path else 'N/A'}") |
| print(f" rectified_weights is None: {args.rectified_weights is None}") |
| print(f" rectified_weights is empty: {args.rectified_weights == '' if args.rectified_weights else 'N/A'}") |
| print("=" * 80) |
|
|
| |
| RectifiedNoiseModule, SD3WithRectifiedNoise = dynamic_import_training_classes(str(Path(__file__).parent)) |
|
|
| |
| dtype = torch.float16 if args.mixed_precision == "fp16" else (torch.bfloat16 if args.mixed_precision == "bf16" else torch.float32) |
| if rank == 0: |
| print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path} (dtype={dtype})") |
| pipeline = StableDiffusion3Pipeline.from_pretrained( |
| args.pretrained_model_name_or_path, |
| revision=args.revision, |
| variant=args.variant, |
| torch_dtype=dtype, |
| ).to(device) |
|
|
| print(f"[rank{rank}] Pipeline loaded and moved to device {device}") |
|
|
| |
| lora_loaded = False |
| if args.lora_path: |
| if rank == 0: |
| print(f"Attempting to load LoRA weights from: {args.lora_path}") |
| print(f"LoRA path exists: {os.path.exists(args.lora_path) if args.lora_path else False}") |
| |
| |
| if check_lora_weights_exist(args.lora_path): |
| if rank == 0: |
| print("Found standard LoRA weights, loading...") |
| try: |
| |
| if rank == 0: |
| sample_param_before = next(iter(pipeline.transformer.parameters())).clone() |
| print(f"Sample transformer param before LoRA (first 5 values): {sample_param_before.flatten()[:5]}") |
| |
| pipeline.load_lora_weights(args.lora_path) |
| lora_loaded = True |
| |
| |
| if rank == 0: |
| sample_param_after = next(iter(pipeline.transformer.parameters())).clone() |
| param_diff = (sample_param_after - sample_param_before).abs().max().item() |
| print(f"Sample transformer param after LoRA (first 5 values): {sample_param_after.flatten()[:5]}") |
| print(f"Max parameter change after LoRA loading: {param_diff}") |
| if param_diff < 1e-6: |
| print("⚠️ WARNING: LoRA weights may not have been applied (parameter change is very small)") |
| else: |
| print("✓ LoRA weights appear to have been applied") |
| |
| |
| if hasattr(pipeline.transformer, 'peft_config'): |
| print(f"✓ PEFT config found: {list(pipeline.transformer.peft_config.keys())}") |
| else: |
| print("⚠️ WARNING: No peft_config found after loading LoRA") |
| |
| if rank == 0: |
| print("LoRA loaded successfully from standard format.") |
| except Exception as e: |
| if rank == 0: |
| print(f"Failed to load LoRA from standard format: {e}") |
| import traceback |
| traceback.print_exc() |
| |
| |
| if not lora_loaded and os.path.isdir(args.lora_path): |
| if rank == 0: |
| print("Standard LoRA weights not found, trying accelerator checkpoint format...") |
| |
| |
| |
| detected_rank = None |
| try: |
| from safetensors.torch import load_file |
| model_file = os.path.join(args.lora_path, "model.safetensors") |
| if os.path.exists(model_file): |
| state_dict = load_file(model_file) |
| |
| for key, value in state_dict.items(): |
| if 'lora_A' in key and 'transformer' in key and len(value.shape) == 2: |
| |
| detected_rank = value.shape[0] |
| if rank == 0: |
| print(f"✓ Detected LoRA rank from checkpoint: {detected_rank} (from key: {key})") |
| break |
| except Exception as e: |
| if rank == 0: |
| print(f"Could not detect rank from checkpoint: {e}") |
| |
| |
| |
| |
| if detected_rank is not None: |
| rank_list = [detected_rank] |
| if rank == 0: |
| print(f"Using detected rank: {detected_rank}") |
| else: |
| |
| rank_list = [] |
| |
| if hasattr(args, 'lora_rank') and args.lora_rank: |
| rank_list.append(args.lora_rank) |
| |
| for r in [32, 64, 16, 128]: |
| if r not in rank_list: |
| rank_list.append(r) |
| if rank == 0: |
| print(f"Rank detection failed, will try ranks in order: {rank_list}") |
| |
| |
| for lora_rank in rank_list: |
| |
| |
| if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config: |
| if "default" in pipeline.transformer.peft_config: |
| try: |
| |
| pipeline.unload_lora_weights() |
| if rank == 0: |
| print(f"Cleaned up existing adapter before trying rank={lora_rank}") |
| except Exception as e: |
| if rank == 0: |
| print(f"Warning: Could not unload adapter: {e}") |
| |
| if rank == 0: |
| print("⚠️ WARNING: Cannot unload adapter, will recreate pipeline...") |
| |
| try: |
| pipeline = StableDiffusion3Pipeline.from_pretrained( |
| args.pretrained_model_name_or_path, |
| revision=args.revision, |
| variant=args.variant, |
| torch_dtype=dtype, |
| ).to(device) |
| if rank == 0: |
| print("Pipeline recreated to clear adapter state") |
| except Exception as e2: |
| if rank == 0: |
| print(f"Failed to recreate pipeline: {e2}") |
| |
| if rank == 0: |
| print(f"Trying to load with LoRA rank={lora_rank}...") |
| lora_loaded = load_lora_from_checkpoint(pipeline, args.lora_path, rank=rank, lora_rank=lora_rank) |
| if lora_loaded: |
| if rank == 0: |
| print(f"✓ Successfully loaded LoRA with rank={lora_rank}") |
| break |
| elif rank == 0: |
| print(f"✗ Failed to load with rank={lora_rank}, trying next rank...") |
| |
| |
| if not lora_loaded and os.path.isdir(args.lora_path): |
| |
| output_dir = os.path.dirname(args.lora_path.rstrip('/')) |
| if output_dir and os.path.exists(output_dir): |
| if rank == 0: |
| print(f"Trying to load standard LoRA weights from output directory: {output_dir}") |
| if check_lora_weights_exist(output_dir): |
| try: |
| pipeline.load_lora_weights(output_dir) |
| lora_loaded = True |
| if rank == 0: |
| print("LoRA loaded successfully from output directory.") |
| except Exception as e: |
| if rank == 0: |
| print(f"Failed to load LoRA from output directory: {e}") |
| |
| if not lora_loaded: |
| if rank == 0: |
| print(f"⚠️ WARNING: Failed to load LoRA weights from {args.lora_path}, using baseline model") |
| else: |
| |
| if rank == 0: |
| print("=" * 80) |
| print("LoRA 加载验证:") |
| if hasattr(pipeline.transformer, 'peft_config') and pipeline.transformer.peft_config: |
| print(f" ✓ PEFT config exists: {list(pipeline.transformer.peft_config.keys())}") |
| |
| lora_layers_found = 0 |
| for name, module in pipeline.transformer.named_modules(): |
| if 'lora_A' in name or 'lora_B' in name: |
| lora_layers_found += 1 |
| if lora_layers_found <= 3: |
| if hasattr(module, 'weight'): |
| weight_sum = module.weight.abs().sum().item() if module.weight is not None else 0 |
| print(f" ✓ Found LoRA layer: {name}, weight_sum={weight_sum:.6f}") |
| print(f" ✓ Total LoRA layers found: {lora_layers_found}") |
| if lora_layers_found == 0: |
| print(" ⚠️ WARNING: No LoRA layers found in transformer!") |
| else: |
| print(" ⚠️ WARNING: No PEFT config found - LoRA may not be active!") |
| print("=" * 80) |

    # ------------------------------------------------------------------
    # Rectified Noise (SIT) module: opt-in via --rectified_weights.
    # An empty / whitespace-only argument is treated as "disabled".
    use_rectified = False
    rectified_weights_path = None
    if args.rectified_weights:
        rectified_weights_str = str(args.rectified_weights).strip()
        if rectified_weights_str:
            use_rectified = True
            rectified_weights_path = rectified_weights_str

    if rank == 0:
        print(f"use_rectified: {use_rectified}, rectified_weights_path: {rectified_weights_path}")

    if use_rectified:
        if rank == 0:
            print(f"Using Rectified Noise module with weights from: {rectified_weights_path}")
            print(f"[rank{rank}] RectifiedNoiseModule configuration: num_sit_layers={args.num_sit_layers}")

        # Derive the SIT hidden size from the transformer config, preferring
        # joint_attention_dim > inner_dim > hidden_size, with 4096 as the
        # last-resort fallback.
        tfm = pipeline.transformer
        if hasattr(tfm.config, 'joint_attention_dim') and tfm.config.joint_attention_dim is not None:
            sit_hidden_size = tfm.config.joint_attention_dim
        elif hasattr(tfm.config, 'inner_dim') and tfm.config.inner_dim is not None:
            sit_hidden_size = tfm.config.inner_dim
        elif hasattr(tfm.config, 'hidden_size') and tfm.config.hidden_size is not None:
            sit_hidden_size = tfm.config.hidden_size
        else:
            sit_hidden_size = 4096

        # Remaining dimensions fall back to defaults when absent from the config.
        transformer_hidden_size = getattr(tfm.config, 'hidden_size', 1536)
        num_attention_heads = getattr(tfm.config, 'num_attention_heads', 32)
        input_dim = getattr(tfm.config, 'in_channels', 16)

        # Must mirror the architecture used by train_rectified_noise.py so the
        # checkpoint state_dict lines up.
        rectified_module = RectifiedNoiseModule(
            hidden_size=sit_hidden_size,
            num_sit_layers=args.num_sit_layers,
            num_attention_heads=num_attention_heads,
            input_dim=input_dim,
            transformer_hidden_size=transformer_hidden_size,
        )

        # Load SIT weights; on failure we keep going with random-init weights
        # (only a warning is printed).
        ok = load_sit_weights(rectified_module, rectified_weights_path, rank=rank)
        if rank == 0:
            if not ok:
                print("⚠️ Warning: Failed to load rectified weights, will use baseline model without rectified noise")
            else:
                print("✓ Successfully loaded rectified noise weights")

        if lora_loaded and rank == 0:
            print("Creating SD3WithRectifiedNoise with LoRA-enabled transformer...")
        elif rank == 0:
            print("Creating SD3WithRectifiedNoise...")

        # Wrap the (possibly LoRA-patched) transformer together with the SIT module.
        model = SD3WithRectifiedNoise(pipeline.transformer, rectified_module).to(device)

        if lora_loaded:
            # Wrapping can drop the active PEFT adapter; re-enable and verify it.
            if hasattr(model.transformer, 'peft_config'):
                try:
                    # NOTE(review): adapter name "default_0" here vs "default"
                    # further below — confirm which name PEFT actually registered.
                    model.transformer.set_adapter("default_0")

                    # Count LoRA layers and how many carry non-zero weights.
                    lora_layers_after_wrap = 0
                    nonzero_after_wrap = 0
                    for name, module in model.transformer.named_modules():
                        if 'lora_A' in name or 'lora_B' in name:
                            lora_layers_after_wrap += 1
                            if hasattr(module, 'weight') and module.weight is not None:
                                if module.weight.abs().sum().item() > 1e-6:
                                    nonzero_after_wrap += 1

                    if rank == 0:
                        print(f"LoRA after SD3WithRectifiedNoise wrapping:")
                        print(f" LoRA layers: {lora_layers_after_wrap}, Non-zero: {nonzero_after_wrap}")
                        if nonzero_after_wrap == 0:
                            print(" ❌ ERROR: All LoRA weights are zero after wrapping!")
                        elif nonzero_after_wrap < lora_layers_after_wrap * 0.5:
                            print(f" ⚠️ WARNING: Only {nonzero_after_wrap}/{lora_layers_after_wrap} LoRA layers have weights!")
                        else:
                            print(f" ✓ LoRA weights preserved after wrapping")

                    # active_adapters is a property on some PEFT versions and a
                    # method on others — handle both shapes.
                    if hasattr(model.transformer, 'active_adapters'):
                        try:
                            if callable(model.transformer.active_adapters):
                                active = model.transformer.active_adapters()
                            else:
                                active = model.transformer.active_adapters
                            if rank == 0:
                                print(f" Active adapters: {active}")
                        except:
                            if rank == 0:
                                print(" LoRA adapter re-enabled after model wrapping")
                    else:
                        if rank == 0:
                            print(" LoRA adapter re-enabled after model wrapping")
                except Exception as e:
                    if rank == 0:
                        print(f"❌ ERROR: Could not re-enable LoRA adapter: {e}")
                        import traceback
                        traceback.print_exc()
            else:
                # No peft_config: LoRA weights were merged directly into the
                # base layers, so there is no adapter to (re-)activate.
                if rank == 0:
                    print("LoRA loaded via merged weights (no PEFT adapter needed)")
                    print(" ✓ LoRA weights are already merged into transformer base weights")
                    print(" Note: This is expected when loading from merged checkpoint format")

        # NOTE(review): StableDiffusion3Pipeline does not consume a `.model`
        # attribute by default — confirm the sampling path actually routes
        # through `pipeline.model`; otherwise the SIT module is never invoked
        # during inference.
        pipeline.model = model

        # Inference mode for both wrapper and inner transformer.
        model.eval()
        model.transformer.eval()
    else:
        if rank == 0:
            print("Not using Rectified Noise module, using baseline SD3 pipeline")

    # ------------------------------------------------------------------
    # Final LoRA activation check on whichever transformer will actually run.
    # NOTE(review): nesting reconstructed from a formatting-mangled source —
    # verify the `else:` pairing against the original file.
    if lora_loaded:
        transformer_ref = model.transformer if use_rectified else pipeline.transformer

        # Best-effort re-activation under the name "default" (vs "default_0"
        # used earlier — see the note above in the wrapping section).
        if hasattr(transformer_ref, 'set_adapter'):
            try:
                transformer_ref.set_adapter("default")
            except:
                pass

            if rank == 0:
                # Verify at least one "default" lora_A layer carries non-zero weight.
                lora_found = False
                for name, module in transformer_ref.named_modules():
                    if 'lora_A' in name and 'default' in name and hasattr(module, 'weight'):
                        if module.weight is not None:
                            weight_sum = module.weight.abs().sum().item()
                            if weight_sum > 0:
                                print(f"✓ Verified LoRA weight in {name}: sum={weight_sum:.6f}")
                                lora_found = True
                                break

                if not lora_found:
                    print("⚠ Warning: Could not verify LoRA weights in model")
        else:
            # No set_adapter API (merged-weights path): just report the first
            # LoRA-typed layer found, if any exposes a lora_enabled flag.
            for name, module in transformer_ref.named_modules():
                if hasattr(module, '__class__') and 'lora' in module.__class__.__name__.lower():
                    if hasattr(module, 'lora_enabled'):
                        enabled = module.lora_enabled
                        if rank == 0:
                            print(f"✓ Found LoRA layer {name}, enabled: {enabled}")
                    break

        print("Model set to eval mode, LoRA should be active during inference")

    # ------------------------------------------------------------------
    # Optional memory savers, all opt-in via CLI flags.
    if args.enable_attention_slicing:
        if rank == 0:
            print("Enabling attention slicing to save memory")
        pipeline.enable_attention_slicing()

    if args.enable_vae_slicing:
        if rank == 0:
            print("Enabling VAE slicing to save memory")
        pipeline.enable_vae_slicing()

    if args.enable_cpu_offload:
        if rank == 0:
            print("Enabling CPU offload to save memory")
        pipeline.enable_model_cpu_offload()

    # Silence the per-image diffusers progress bar; the tqdm loop below
    # tracks batch progress instead.
    pipeline.set_progress_bar_config(disable=True)

    # ------------------------------------------------------------------
    # Workload: one image per (caption, repeat) pair, capped by --max_samples.
    captions = load_captions_from_jsonl(args.captions_jsonl)
    total_images_needed = min(len(captions) * args.images_per_caption, args.max_samples)

    # Rank 0 creates the output dir; all ranks wait until it exists.
    if rank == 0:
        os.makedirs(args.sample_dir, exist_ok=True)
    dist.barrier()

    # Resume support: count already-saved PNGs and only generate the remainder.
    existing_count, max_existing_index = get_existing_sample_count(args.sample_dir)
    if rank == 0:
        print(f"Found {existing_count} existing samples, max index: {max_existing_index}")

    remaining_images_needed = max(0, total_images_needed - existing_count)
    if remaining_images_needed == 0:
        # Nothing to generate — rank 0 just (re)builds the npz archive.
        if rank == 0:
            print("All required samples already exist. Skipping generation.")
            print(f"Creating npz from existing samples...")
            create_npz_from_sample_folder(args.sample_dir, total_images_needed)
        return

    if rank == 0:
        print(f"Need to generate {remaining_images_needed} more samples (total needed: {total_images_needed})")

    # Round the remaining work up to a whole number of global batches so every
    # rank runs the same number of iterations (required for the per-iteration
    # barrier in the sampling loop).
    n = args.per_proc_batch_size
    global_batch = n * world_size
    total_samples = int(math.ceil(remaining_images_needed / global_batch) * global_batch)
    assert total_samples % world_size == 0
    samples_per_gpu = total_samples // world_size
    assert samples_per_gpu % n == 0
    iterations = samples_per_gpu // n

    if rank == 0:
        print(f"Sampling remaining={remaining_images_needed}, total_samples={total_samples}, per_gpu={samples_per_gpu}, iterations={iterations}")

    # ------------------------------------------------------------------
    # Distributed sampling loop. Sample index layout:
    #   idx = it * global_batch + j * world_size + rank
    # so ranks interleave and filenames are globally unique.
    pbar = tqdm(range(iterations)) if rank == 0 else range(iterations)
    saved = 0

    autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
    for it in pbar:
        if rank == 0 and it % 10 == 0:
            print(f"[rank{rank}] Sampling iteration {it}/{iterations}")
        batch_prompts = []
        # NOTE(review): base_index is never used below — candidate for removal.
        base_index = it * global_batch + rank
        for j in range(n):
            idx = it * global_batch + j * world_size + rank
            if idx < remaining_images_needed:
                # NOTE(review): on resume, idx restarts at 0 (ignores
                # existing_count), so resumed runs re-use the first captions
                # even though filenames continue — confirm this is intended.
                cap_idx = idx // args.images_per_caption
                batch_prompts.append(captions[cap_idx])
            else:
                # Padding prompt for the rounded-up tail; its images are
                # discarded at save time by the total_images_needed check.
                batch_prompts.append("a beautiful high quality image")

        with torch.autocast(autocast_device, dtype=dtype):
            images = []
            for k, prompt in enumerate(batch_prompts):
                # Deterministic per-image seed, unique across (iter, slot, rank).
                image_seed = seed + it * 10000 + k * 1000 + rank
                generator = torch.Generator(device=device).manual_seed(image_seed)
                img = pipeline(
                    prompt=prompt,
                    height=args.height,
                    width=args.width,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    num_images_per_prompt=1,
                ).images[0]
                images.append(img)

        out_dir = Path(args.sample_dir)
        if rank == 0 and it == 0:
            print(f"Saving pngs to: {out_dir}")
        for j, img in enumerate(images):
            # Shift by existing_count so resumed runs append after prior files.
            global_index = it * global_batch + j * world_size + rank + existing_count
            if global_index < total_images_needed:
                filename = f"{global_index:07d}.png"
                img.save(out_dir / filename)
                saved += 1
        dist.barrier()

    if rank == 0:
        # NOTE(review): saved * world_size over-counts when ranks save unequal
        # tail batches; the directory scan below reports the true count.
        print(f"Done. Saved {saved * world_size} images in total.")
        actual_num_samples = len([name for name in os.listdir(args.sample_dir) if name.endswith(".png")])
        print(f"Actually generated {actual_num_samples} images")
        npz_samples = min(actual_num_samples, total_images_needed)
        print(f"[rank{rank}] Creating npz from sample folder: {args.sample_dir}, npz_samples={npz_samples}")
        create_npz_from_sample_folder(args.sample_dir, npz_samples)
        print("Done creating npz.")
        print("Done.")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point for the distributed SD3 + LoRA + RectifiedNoise sampler.
    cli = argparse.ArgumentParser(description="SD3 LoRA + RectifiedNoise 分布式采样脚本")

    # Base model selection.
    cli.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    cli.add_argument("--revision", type=str, default=None)
    cli.add_argument("--variant", type=str, default=None)

    # Optional fine-tuned weights.
    cli.add_argument("--lora_path", type=str, default=None, help="LoRA 权重路径(文件或目录)")
    cli.add_argument("--rectified_weights", type=str, default=None, help="Rectified(SIT) 权重路径(文件或目录)")
    cli.add_argument("--num_sit_layers", type=int, default=1, help="与训练一致的 SIT 层数")

    # Sampling hyper-parameters and workload sizing.
    cli.add_argument("--num_inference_steps", type=int, default=28)
    cli.add_argument("--guidance_scale", type=float, default=7.0)
    cli.add_argument("--height", type=int, default=1024)
    cli.add_argument("--width", type=int, default=1024)
    cli.add_argument("--per_proc_batch_size", type=int, default=1)
    cli.add_argument("--images_per_caption", type=int, default=1)
    cli.add_argument("--max_samples", type=int, default=10000)
    cli.add_argument("--captions_jsonl", type=str, required=True)
    cli.add_argument("--sample_dir", type=str, default="sd3_rectified_samples")
    cli.add_argument("--global_seed", type=int, default=42)
    cli.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])

    # Memory-saving switches.
    cli.add_argument("--enable_attention_slicing", action="store_true", help="启用 attention slicing 以节省显存")
    cli.add_argument("--enable_vae_slicing", action="store_true", help="启用 VAE slicing 以节省显存")
    cli.add_argument("--enable_cpu_offload", action="store_true", help="启用 CPU offload 以节省显存")

    main(cli.parse_args())
|
|
|
|