| import torch |
| import os |
| from diffusers import UNet2DConditionModel |
| from diffusers.configuration_utils import FrozenDict |
| from typing import Dict |
|
|
| |
| UNET_PATH = './unet_old' |
| NEW_UNET_PATH = './unet' |
|
|
| |
|
|
| def migrate_unet_weights(old_unet: UNet2DConditionModel, new_unet: UNet2DConditionModel) -> UNet2DConditionModel: |
| """Копирует веса из старой UNet в новую, игнорируя новые слои.""" |
| old_state_dict = old_unet.state_dict() |
| new_state_dict = new_unet.state_dict() |
| |
| missing_keys = [] |
| |
| |
| for name, param in new_state_dict.items(): |
| if name in old_state_dict: |
| if param.shape == old_state_dict[name].shape: |
| param.data.copy_(old_state_dict[name].data) |
| else: |
| print(f"⚠️ Пропуск ключа {name}: не совпадают формы ({old_state_dict[name].shape} vs {param.shape})") |
| else: |
| missing_keys.append(name) |
| |
| print(f"✅ Успешно перенесено {len(new_state_dict) - len(missing_keys)} весов.") |
| print("\n--- Новые слои (случайные веса, требуют дообучения) ---") |
| for key in missing_keys: |
| print(f"🆕 {key}") |
| print("----------------------------------------------------------") |
| |
| return new_unet |
|
|
| |
|
|
| print(f"1. Загрузка UNet из {UNET_PATH}...") |
| try: |
| |
| old_unet = UNet2DConditionModel.from_pretrained(UNET_PATH, torch_dtype=torch.float32) |
| old_config: Dict = old_unet.config |
| print(" -> Исходная UNet успешно загружена.") |
| except Exception as e: |
| print(f"🛑 Ошибка при загрузке UNet: {e}") |
| exit() |
|
|
| |
| print("2. Модификация конфигурации...") |
|
|
| |
| new_config = dict(old_config) |
| new_config.update({ |
| |
| "addition_embed_type": "text", |
| |
| |
| "addition_time_embed_dim": 1024, |
| |
| |
| "encoder_hid_dim_type": "text_proj", |
| |
| |
| "encoder_hid_dim": 1024, |
| |
| |
| |
| "projection_class_embeddings_input_dim": 1024, |
| |
| |
| "time_embedding_dim": 1024, |
| |
| |
| "addition_embed_type_num_heads": 64, |
| }) |
|
|
| |
| print("3. Инициализация новой UNet с измененной архитектурой...") |
| new_config_frozen = FrozenDict(new_config) |
| |
| new_unet = UNet2DConditionModel.from_config(new_config_frozen) |
| print(" -> Новая UNet инициализирована (новые слои имеют случайные веса).") |
|
|
| |
|
|
| print("4. Выполнение миграции весов...") |
| migrated_unet = migrate_unet_weights(old_unet, new_unet) |
|
|
| |
|
|
| print(f"5. Сохранение новой UNet в папку {NEW_UNET_PATH}...") |
|
|
| |
| os.makedirs(NEW_UNET_PATH, exist_ok=True) |
|
|
| |
| migrated_unet.save_pretrained(NEW_UNET_PATH) |
|
|
| print("🎉 Готово! Новая UNet готова к использованию и дообучению.") |
|
|
| print(f"\nСледующий шаг: Замените путь к UNet в вашем SdxsPipeline на '{NEW_UNET_PATH}' и запустите инференс, чтобы убедиться, что она принимает `added_cond_kwargs` без ошибок.") |