{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "82ca7882-410c-4067-863a-07838d485f6a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test unet\n", "Количество параметров: 1344407376\n", "Output shape: torch.Size([1, 16, 60, 48])\n", "Output shape: torch.Size([1, 16, 60, 48])\n" ] } ], "source": [ "config_sdxs = {\n", " # === Основные размеры и каналы ===\n", " \"in_channels\": 16, # Количество входных каналов (совместимость с 16-канальным VAE)\n", " \"out_channels\": 16, # Количество выходных каналов (симметрично in_channels)\n", " \"center_input_sample\": False, # Отключение центрирования входных данных (стандарт для диффузионных моделей)\n", " \"flip_sin_to_cos\": True, # Автоматическое преобразование sin/cos в эмбеддингах времени (для стабильности)\n", " \"freq_shift\": 0, # Сдвиг частоты (0 - стандартное значение для частотных эмбеддингов)\n", "\n", " # === Архитектура блоков ===\n", " \"down_block_types\": [ # Типы блоков энкодера (иерархия обработки):\n", " \"CrossAttnDownBlock2D\",\n", " \"CrossAttnDownBlock2D\",\n", " \"CrossAttnDownBlock2D\",\n", " \"DownBlock2D\"\n", " ],\n", " \"mid_block_type\": \"UNetMidBlock2DCrossAttn\", # Центральный блок с cross-attention (бутылочное горлышко сети)\n", " \"up_block_types\": [ # Типы блоков декодера (восстановление изображения):\n", " \"UpBlock2D\",\n", " \"CrossAttnUpBlock2D\",\n", " \"CrossAttnUpBlock2D\",\n", " \"CrossAttnUpBlock2D\",\n", " ],\n", " \"only_cross_attention\": False, # Использование как cross-attention, так и self-attention\n", "\n", " # === Конфигурация каналов ===\n", " \"block_out_channels\": [320, 640, 1280, 1280], \n", " \"layers_per_block\": 2, # Число слоев в блоках\n", " \"downsample_padding\": 1, # Паддинг при уменьшении разрешения\n", " \"mid_block_scale_factor\": 1.0, # Усиление сигнала в центральном блоке\n", "\n", " # === Нормализация ===\n", " \"norm_num_groups\": 32, # Число групп для GroupNorm (оптимально для стабильности)\n", 
" \"norm_eps\": 1e-05, # Эпсилон для нормализации (стандартное значение)\n", "\n", " # === Cross-Attention ===\n", " \"cross_attention_dim\": 768, # Размерность текстовых эмбеддинго\n", " \n", " \"transformer_layers_per_block\": 3, # Число трансформерных слоев (уменьшение с глубиной)\n", " \"attention_head_dim\": 8, # Размерность головы внимания \n", " \"dual_cross_attention\": False, # Отключение двойного внимания (упрощение архитектуры)\n", " \"use_linear_projection\": False, # Изменено на True для лучшей организации памяти\n", "\n", " # === ResNet Блоки ===\n", " \"resnet_time_scale_shift\": \"default\", # Способ интеграции временных эмбеддингов\n", " \"resnet_skip_time_act\": False, # Отключение активации в skip-соединениях\n", " \"resnet_out_scale_factor\": 1.0, # Коэффициент масштабирования выхода ResNet\n", "\n", " # === Временные эмбеддинги ===\n", " \"time_embedding_type\": \"positional\", # Тип временных эмбеддингов (стандартный подход)\n", "\n", " # === Свертки ===\n", " \"conv_in_kernel\": 3, # Ядро входной свертки (баланс между рецептивным полем и параметрами)\n", " \"conv_out_kernel\": 3, # Ядро выходной свертки (симметрично входной)\n", "}\n", "\n", "if 1:\n", " checkpoint_path = \"sd15_tmp\"#\"sdxs\"\n", " import torch\n", " from diffusers import UNet2DConditionModel\n", " print(\"test unet\")\n", " new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n", "\n", " assert all(ch % 32 == 0 for ch in new_unet.config[\"block_out_channels\"]), \"Каналы должны быть кратны 32\"\n", " num_params = sum(p.numel() for p in new_unet.parameters())\n", " print(f\"Количество параметров: {num_params}\")\n", "\n", " # Генерация тестового латента (640x512 в latent space)\n", " test_latent = torch.randn(1, 16, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n", " timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n", " encoder_hidden_states = torch.randn(1, 77, 768).to(\"cuda\", dtype=torch.float16)\n", " \n", " with 
torch.no_grad():\n", " output = new_unet(\n", " test_latent, \n", " timesteps, \n", " encoder_hidden_states\n", " ).sample\n", " \n", " print(f\"Output shape: {output.shape}\") \n", " new_unet.save_pretrained(checkpoint_path)\n", " #print(new_unet)\n", " del new_unet\n", " torch.cuda.empty_cache()\n", " print(f\"Output shape: {output.shape}\") \n", " # Количество параметров: 1101998736 1344407376" ] }, { "cell_type": "code", "execution_count": 3, "id": "f980bb1a-9859-44c2-a2df-ff1b073bf435", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Перенос весов: 100%|██████████| 1006/1006 [00:00<00:00, 36208.99it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Статистика переноса: {'перенесено': 1006, 'несовпадение_размеров': 0, 'пропущено': 0}\n", "Неперенесенные ключи в новой модели:\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n", "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias\n", "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.norm1.bias\n", "down_blocks.0.attentions.0.transformer_blocks.2.norm1.weight\n", 
"down_blocks.0.attentions.0.transformer_blocks.2.norm2.bias\n", "down_blocks.0.attentions.0.transformer_blocks.2.norm2.weight\n", "down_blocks.0.attentions.0.transformer_blocks.2.norm3.bias\n", "down_blocks.0.attentions.0.transformer_blocks.2.norm3.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n", "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias\n", "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.norm1.bias\n", "down_blocks.0.attentions.1.transformer_blocks.2.norm1.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.norm2.bias\n", "down_blocks.0.attentions.1.transformer_blocks.2.norm2.weight\n", "down_blocks.0.attentions.1.transformer_blocks.2.norm3.bias\n", "down_blocks.0.attentions.1.transformer_blocks.2.norm3.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight\n", 
"down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n", "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias\n", "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.norm1.bias\n", "down_blocks.1.attentions.0.transformer_blocks.2.norm1.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.norm2.bias\n", "down_blocks.1.attentions.0.transformer_blocks.2.norm2.weight\n", "down_blocks.1.attentions.0.transformer_blocks.2.norm3.bias\n", "down_blocks.1.attentions.0.transformer_blocks.2.norm3.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n", "down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n", 
"down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias\n", "down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.norm1.bias\n", "down_blocks.1.attentions.1.transformer_blocks.2.norm1.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.norm2.bias\n", "down_blocks.1.attentions.1.transformer_blocks.2.norm2.weight\n", "down_blocks.1.attentions.1.transformer_blocks.2.norm3.bias\n", "down_blocks.1.attentions.1.transformer_blocks.2.norm3.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n", "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias\n", "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias\n", "down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias\n", "down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight\n", "down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias\n", "down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight\n", 
"down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n", "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias\n", "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias\n", "down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias\n", "down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight\n", "down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias\n", "down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight\n", "mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight\n", "mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n", "mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n", "mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight\n", "mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight\n", "mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight\n", "mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n", "mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n", "mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight\n", 
"mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight\n", "mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n", "mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n", "mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias\n", "mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight\n", "mid_block.attentions.0.transformer_blocks.2.norm1.bias\n", "mid_block.attentions.0.transformer_blocks.2.norm1.weight\n", "mid_block.attentions.0.transformer_blocks.2.norm2.bias\n", "mid_block.attentions.0.transformer_blocks.2.norm2.weight\n", "mid_block.attentions.0.transformer_blocks.2.norm3.bias\n", "mid_block.attentions.0.transformer_blocks.2.norm3.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.norm1.bias\n", "up_blocks.1.attentions.0.transformer_blocks.2.norm1.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.norm2.bias\n", "up_blocks.1.attentions.0.transformer_blocks.2.norm2.weight\n", "up_blocks.1.attentions.0.transformer_blocks.2.norm3.bias\n", 
"up_blocks.1.attentions.0.transformer_blocks.2.norm3.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.norm1.bias\n", "up_blocks.1.attentions.1.transformer_blocks.2.norm1.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.norm2.bias\n", "up_blocks.1.attentions.1.transformer_blocks.2.norm2.weight\n", "up_blocks.1.attentions.1.transformer_blocks.2.norm3.bias\n", "up_blocks.1.attentions.1.transformer_blocks.2.norm3.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.weight\n", 
"up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.norm1.bias\n", "up_blocks.1.attentions.2.transformer_blocks.2.norm1.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.norm2.bias\n", "up_blocks.1.attentions.2.transformer_blocks.2.norm2.weight\n", "up_blocks.1.attentions.2.transformer_blocks.2.norm3.bias\n", "up_blocks.1.attentions.2.transformer_blocks.2.norm3.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.norm1.bias\n", "up_blocks.2.attentions.0.transformer_blocks.2.norm1.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.norm2.bias\n", 
"up_blocks.2.attentions.0.transformer_blocks.2.norm2.weight\n", "up_blocks.2.attentions.0.transformer_blocks.2.norm3.bias\n", "up_blocks.2.attentions.0.transformer_blocks.2.norm3.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.norm1.bias\n", "up_blocks.2.attentions.1.transformer_blocks.2.norm1.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.norm2.bias\n", "up_blocks.2.attentions.1.transformer_blocks.2.norm2.weight\n", "up_blocks.2.attentions.1.transformer_blocks.2.norm3.bias\n", "up_blocks.2.attentions.1.transformer_blocks.2.norm3.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_k.weight\n", 
"up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.norm1.bias\n", "up_blocks.2.attentions.2.transformer_blocks.2.norm1.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.norm2.bias\n", "up_blocks.2.attentions.2.transformer_blocks.2.norm2.weight\n", "up_blocks.2.attentions.2.transformer_blocks.2.norm3.bias\n", "up_blocks.2.attentions.2.transformer_blocks.2.norm3.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.norm1.bias\n", 
"up_blocks.3.attentions.0.transformer_blocks.2.norm1.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.norm2.bias\n", "up_blocks.3.attentions.0.transformer_blocks.2.norm2.weight\n", "up_blocks.3.attentions.0.transformer_blocks.2.norm3.bias\n", "up_blocks.3.attentions.0.transformer_blocks.2.norm3.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_q.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.norm1.bias\n", "up_blocks.3.attentions.1.transformer_blocks.2.norm1.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.norm2.bias\n", "up_blocks.3.attentions.1.transformer_blocks.2.norm2.weight\n", "up_blocks.3.attentions.1.transformer_blocks.2.norm3.bias\n", "up_blocks.3.attentions.1.transformer_blocks.2.norm3.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_k.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_out.0.bias\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_out.0.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_q.weight\n", 
"up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_v.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_k.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_out.0.bias\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_out.0.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_q.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_v.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.0.proj.bias\n", "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.0.proj.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.2.bias\n", "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.2.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.norm1.bias\n", "up_blocks.3.attentions.2.transformer_blocks.2.norm1.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.norm2.bias\n", "up_blocks.3.attentions.2.transformer_blocks.2.norm2.weight\n", "up_blocks.3.attentions.2.transformer_blocks.2.norm3.bias\n", "up_blocks.3.attentions.2.transformer_blocks.2.norm3.weight\n", "UNet2DConditionModel(\n", " (conv_in): Conv2d(16, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_proj): Timesteps()\n", " (time_embedding): TimestepEmbedding(\n", " (linear_1): Linear(in_features=320, out_features=1280, bias=True)\n", " (act): SiLU()\n", " (linear_2): Linear(in_features=1280, out_features=1280, bias=True)\n", " )\n", " (down_blocks): ModuleList(\n", " (0): CrossAttnDownBlock2D(\n", " (attentions): ModuleList(\n", " (0-1): 2 x Transformer2DModel(\n", " (norm): GroupNorm(32, 320, eps=1e-06, affine=True)\n", " (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n", " (transformer_blocks): ModuleList(\n", " (0-2): 3 x BasicTransformerBlock(\n", " (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (attn1): Attention(\n", " (to_q): Linear(in_features=320, out_features=320, bias=False)\n", " (to_k): Linear(in_features=320, out_features=320, 
bias=False)\n", " (to_v): Linear(in_features=320, out_features=320, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=320, out_features=320, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (attn2): Attention(\n", " (to_q): Linear(in_features=320, out_features=320, bias=False)\n", " (to_k): Linear(in_features=768, out_features=320, bias=False)\n", " (to_v): Linear(in_features=768, out_features=320, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=320, out_features=320, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (ff): FeedForward(\n", " (net): ModuleList(\n", " (0): GEGLU(\n", " (proj): Linear(in_features=320, out_features=2560, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=1280, out_features=320, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)\n", " (conv1): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)\n", " (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (1): CrossAttnDownBlock2D(\n", " (attentions): ModuleList(\n", " (0-1): 2 x Transformer2DModel(\n", " (norm): GroupNorm(32, 640, eps=1e-06, affine=True)\n", " (proj_in): Conv2d(640, 640, kernel_size=(1, 
1), stride=(1, 1))\n", " (transformer_blocks): ModuleList(\n", " (0-2): 3 x BasicTransformerBlock(\n", " (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n", " (attn1): Attention(\n", " (to_q): Linear(in_features=640, out_features=640, bias=False)\n", " (to_k): Linear(in_features=640, out_features=640, bias=False)\n", " (to_v): Linear(in_features=640, out_features=640, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=640, out_features=640, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n", " (attn2): Attention(\n", " (to_q): Linear(in_features=640, out_features=640, bias=False)\n", " (to_k): Linear(in_features=768, out_features=640, bias=False)\n", " (to_v): Linear(in_features=768, out_features=640, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=640, out_features=640, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n", " (ff): FeedForward(\n", " (net): ModuleList(\n", " (0): GEGLU(\n", " (proj): Linear(in_features=640, out_features=5120, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=2560, out_features=640, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)\n", " (conv1): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n", " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 
1))\n", " )\n", " (1): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (conv1): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n", " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (2): CrossAttnDownBlock2D(\n", " (attentions): ModuleList(\n", " (0-1): 2 x Transformer2DModel(\n", " (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)\n", " (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " (transformer_blocks): ModuleList(\n", " (0-2): 3 x BasicTransformerBlock(\n", " (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (attn1): Attention(\n", " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_k): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_v): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=1280, out_features=1280, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (attn2): Attention(\n", " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_k): Linear(in_features=768, out_features=1280, bias=False)\n", " (to_v): Linear(in_features=768, out_features=1280, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=1280, out_features=1280, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (ff): FeedForward(\n", " (net): 
ModuleList(\n", " (0): GEGLU(\n", " (proj): Linear(in_features=1280, out_features=10240, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=5120, out_features=1280, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (conv1): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (3): DownBlock2D(\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): 
Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " )\n", " (up_blocks): ModuleList(\n", " (0): UpBlock2D(\n", " (resnets): ModuleList(\n", " (0-2): 3 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)\n", " (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (1): CrossAttnUpBlock2D(\n", " (attentions): ModuleList(\n", " (0-2): 3 x Transformer2DModel(\n", " (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)\n", " (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " (transformer_blocks): ModuleList(\n", " (0-2): 3 x BasicTransformerBlock(\n", " (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (attn1): Attention(\n", " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_k): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_v): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=1280, out_features=1280, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (attn2): Attention(\n", " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_k): Linear(in_features=768, out_features=1280, bias=False)\n", " (to_v): Linear(in_features=768, 
out_features=1280, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=1280, out_features=1280, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (ff): FeedForward(\n", " (net): ModuleList(\n", " (0): GEGLU(\n", " (proj): Linear(in_features=1280, out_features=10240, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=5120, out_features=1280, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)\n", " (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (2): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)\n", " (conv1): Conv2d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(1920, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (2): CrossAttnUpBlock2D(\n", " (attentions): 
ModuleList(\n", " (0-2): 3 x Transformer2DModel(\n", " (norm): GroupNorm(32, 640, eps=1e-06, affine=True)\n", " (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))\n", " (transformer_blocks): ModuleList(\n", " (0-2): 3 x BasicTransformerBlock(\n", " (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n", " (attn1): Attention(\n", " (to_q): Linear(in_features=640, out_features=640, bias=False)\n", " (to_k): Linear(in_features=640, out_features=640, bias=False)\n", " (to_v): Linear(in_features=640, out_features=640, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=640, out_features=640, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n", " (attn2): Attention(\n", " (to_q): Linear(in_features=640, out_features=640, bias=False)\n", " (to_k): Linear(in_features=768, out_features=640, bias=False)\n", " (to_v): Linear(in_features=768, out_features=640, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=640, out_features=640, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n", " (ff): FeedForward(\n", " (net): ModuleList(\n", " (0): GEGLU(\n", " (proj): Linear(in_features=640, out_features=5120, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=2560, out_features=640, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)\n", " (conv1): Conv2d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n", " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(640, 640, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(1920, 640, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (conv1): Conv2d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n", " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (2): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)\n", " (conv1): Conv2d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n", " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(960, 640, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (3): CrossAttnUpBlock2D(\n", " (attentions): ModuleList(\n", " (0-2): 3 x Transformer2DModel(\n", " (norm): GroupNorm(32, 320, eps=1e-06, affine=True)\n", " (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n", " (transformer_blocks): ModuleList(\n", " (0-2): 3 x BasicTransformerBlock(\n", " (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (attn1): Attention(\n", " (to_q): Linear(in_features=320, out_features=320, bias=False)\n", " (to_k): Linear(in_features=320, out_features=320, bias=False)\n", " (to_v): Linear(in_features=320, 
out_features=320, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=320, out_features=320, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (attn2): Attention(\n", " (to_q): Linear(in_features=320, out_features=320, bias=False)\n", " (to_k): Linear(in_features=768, out_features=320, bias=False)\n", " (to_v): Linear(in_features=768, out_features=320, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=320, out_features=320, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (ff): FeedForward(\n", " (net): ModuleList(\n", " (0): GEGLU(\n", " (proj): Linear(in_features=320, out_features=2560, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=1280, out_features=320, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)\n", " (conv1): Conv2d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)\n", " (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1-2): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)\n", " (conv1): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)\n", " (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " 
(conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " )\n", " )\n", " (mid_block): UNetMidBlock2DCrossAttn(\n", " (attentions): ModuleList(\n", " (0): Transformer2DModel(\n", " (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)\n", " (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " (transformer_blocks): ModuleList(\n", " (0-2): 3 x BasicTransformerBlock(\n", " (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (attn1): Attention(\n", " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_k): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_v): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=1280, out_features=1280, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (attn2): Attention(\n", " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n", " (to_k): Linear(in_features=768, out_features=1280, bias=False)\n", " (to_v): Linear(in_features=768, out_features=1280, bias=False)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=1280, out_features=1280, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", " (ff): FeedForward(\n", " (net): ModuleList(\n", " (0): GEGLU(\n", " (proj): Linear(in_features=1280, out_features=10240, bias=True)\n", " )\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=5120, out_features=1280, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 1280, eps=1e-05, 
affine=True)\n", " (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " (conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)\n", " (conv_act): SiLU()\n", " (conv_out): Conv2d(320, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", ")\n" ] } ], "source": [
"import torch\n",
"from diffusers import UNet2DConditionModel\n",
"from tqdm.auto import tqdm  # tqdm.auto renders as a widget inside Jupyter\n",
"\n",
"def log(message):\n",
"    \"\"\"Lightweight logging hook; currently just prints the message.\"\"\"\n",
"    print(message)\n",
"\n",
"def main():\n",
"    \"\"\"Copy every weight tensor whose key and shape match from the old UNet\n",
"    checkpoint into the new one, save the result, and report what was (not)\n",
"    transferred.\"\"\"\n",
"    # Source checkpoint (weights to copy) and target checkpoint (new architecture).\n",
"    checkpoint_path_old = \"unet\"\n",
"    checkpoint_path_new = \"sd15_tmp\"\n",
"    # Fall back to CPU so the cell still runs on machines without a GPU.\n",
"    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"    dtype = torch.float16\n",
"\n",
"    # Load both models\n",
"    old_unet = UNet2DConditionModel.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
"    new_unet = UNet2DConditionModel.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
"\n",
"    old_state_dict = old_unet.state_dict()\n",
"    new_state_dict = new_unet.state_dict()\n",
"\n",
"    transferred_state_dict = {}\n",
"    transfer_stats = {\n",
"        \"перенесено\": 0,\n",
"        \"несовпадение_размеров\": 0,\n",
"        \"пропущено\": 0\n",
"    }\n",
"\n",
"    transferred_keys = set()\n",
"\n",
"    # Walk every key of the old model and copy it when the new model has a\n",
"    # same-named parameter of identical shape.\n",
"    for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
"        new_key = old_key\n",
"\n",
"        # Does the key exist in the new model at all?\n",
"        if new_key in new_state_dict:\n",
"            # Shapes must match exactly for a direct copy.\n",
"            if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
"                transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
"                transferred_keys.add(new_key)\n",
"                transfer_stats[\"перенесено\"] += 1\n",
"            else:\n",
"                log(f\"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n",
"                transfer_stats[\"несовпадение_размеров\"] += 1\n",
"        else:\n",
"            log(f\"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}\")\n",
"            transfer_stats[\"пропущено\"] += 1\n",
"\n",
"    # Overwrite the matching entries of the target state dict and persist it.\n",
"    new_state_dict.update(transferred_state_dict)\n",
"    new_unet.load_state_dict(new_state_dict)\n",
"    new_unet.save_pretrained(\"unet_1.3b\")\n",
"\n",
"    # Keys of the new model that received no weights from the old one.\n",
"    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
"\n",
"    print(\"Статистика переноса:\", transfer_stats)\n",
"    print(\"Неперенесенные ключи в новой модели:\")\n",
"    for key in non_transferred_keys:\n",
"        print(key)\n",
"\n",
"    print(new_unet)\n",
"\n",
"if __name__ == \"__main__\":\n",
"    main()\n",
"# Статистика переноса: {'перенесено': 686, 'несовпадение_размеров': 0, 'пропущено': 0}"
] }, { "cell_type": "code", "execution_count": null, "id": "f2438e3d-4b83-4b3f-8e78-53cbcc35f6e4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }