Upload ./infer_inpaint.py with huggingface_hub
Browse files- infer_inpaint.py +97 -0
infer_inpaint.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import pdb
|
| 5 |
+
|
| 6 |
+
from peft import LoraConfig, get_peft_model
|
| 7 |
+
import torch
|
| 8 |
+
from safetensors.torch import load_model, save_model
|
| 9 |
+
from marigold.marigold_inpaint_pipeline import MarigoldInpaintPipeline
|
| 10 |
+
from marigold.duplicate_unet import DoubleUNet2DConditionModel
|
| 11 |
+
import json
|
| 12 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
| 13 |
+
from torchvision.transforms.functional import pil_to_tensor
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import random
|
| 16 |
+
import numpy as np
|
| 17 |
+
from pycocotools import mask as coco_mask
|
| 18 |
+
from diffusers.schedulers import DDIMScheduler, PNDMScheduler
|
| 19 |
+
from torchvision.transforms import InterpolationMode, Resize, CenterCrop
|
| 20 |
+
import torchvision.transforms as transforms
|
| 21 |
+
|
| 22 |
+
# Load the Marigold inpainting pipeline from the Stable Diffusion 2 weights and
# replace its UNet with a DoubleUNet2DConditionModel built from the local SD2
# UNet config file.
model = MarigoldInpaintPipeline.from_pretrained('stabilityai/stable-diffusion-2')
unet_config_path = '/home/aiops/wangzh/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2/snapshots/1e128c8891e52218b74cde8f26dbfc701cb99d79/unet/config.json'
# unet_checkpoint_path = '/home/aiops/wangzh/marigold/768_gen/diffusion_pytorch_model.safetensors'
# Read the config inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open(unet_config_path) as unet_config_file:
    unet_config = json.load(unet_config_file)
model.unet = DoubleUNet2DConditionModel(**unet_config)
# model.unet.load_state_dict(torch.load(unet_checkpoint_path, map_location='cpu'), strict=False)

# Override the configured input channel count before the UNet is duplicated and
# its input convolutions are rebuilt.
# NOTE(review): presumably the inpaint_*_conv_in() helpers read this value —
# confirm against marigold.duplicate_unet.
model.unet.config["in_channels"] = 13
model.unet.duplicate_model()
model.unet.inpaint_rgb_conv_in()
model.unet.inpaint_depth_conv_in()

# Wrap the UNet with LoRA adapters on the attention projections so that the
# parameter names match the fine-tuned checkpoint loaded below.
unet_lora_config = LoraConfig(
    r=128,
    lora_alpha=128,
    init_lora_weights="gaussian",
    target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
)
model.unet = get_peft_model(model.unet, unet_lora_config)

# NOTE: torch.load unpickles the file; only point this at checkpoints you trust.
sd2inpaint_ckpt = torch.load('/home/aiops/wangzh/marigold/output/512-inpaint-0.5-128-vitl-partition/checkpoint/latest/pytorch_model.bin', map_location='cpu')
model.unet.load_state_dict(sd2inpaint_ckpt)
model.to('cuda')
|
| 44 |
+
|
| 45 |
+
# Keyword arguments for each DepthAnythingV2 encoder variant, keyed by the
# variant name. Each row below is (encoder name, feature width, per-stage
# output channels).
model_configs = {
    encoder_name: {
        'encoder': encoder_name,
        'features': feature_width,
        'out_channels': stage_channels,
    }
    for encoder_name, feature_width, stage_channels in (
        ('vits', 64, [48, 96, 192, 384]),
        ('vitb', 128, [96, 192, 384, 768]),
        ('vitl', 256, [256, 512, 1024, 1024]),
        ('vitg', 384, [1536, 1536, 1536, 1536]),
    )
}
|
| 51 |
+
|
| 52 |
+
# Attach two separately loaded DDIM scheduler instances to the pipeline, one
# under each attribute name, both read from the SD2 "scheduler" subfolder.
for scheduler_attr in ('rgb_scheduler', 'depth_scheduler'):
    setattr(
        model,
        scheduler_attr,
        DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler"),
    )
|
| 54 |
+
|
| 55 |
+
# Instantiate the ViT-L variant of DepthAnythingV2, load its weights from the
# local checkpoint on CPU, then move it to the GPU in eval mode.
depth_model = DepthAnythingV2(**model_configs['vitl'])
depth_state_dict = torch.load(
    '/home/aiops/wangzh/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth',
    map_location='cpu',
)
depth_model.load_state_dict(depth_state_dict)
depth_model = depth_model.to('cuda').eval()
|
| 59 |
+
|
| 60 |
+
# Input images; the mask-building code later in this script looks for a
# same-named .json annotation file next to each image.
image_path = [
    '/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
    '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
    '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg',
]

# One text prompt per entry of image_path, in the same order.
prompt = [
    'A white car is parked in front of the factory',
    'church with cemetery next to it',
    'A house with a red brick roof',
]

# Decode each image to a tensor inside a context manager so the underlying
# file handle is released as soon as the pixel data has been copied.
imgs = []
for single_path in image_path:
    with Image.open(single_path) as pil_image:
        imgs.append(pil_to_tensor(pil_image))

# Depth estimation is inference only; disable autograd so no graph is kept.
# NOTE(review): depth_model is called with a NumPy array here. Confirm that
# DepthAnythingV2's forward accepts that; upstream appears to expose a separate
# infer_image() entry point for array input.
with torch.no_grad():
    depth_imgs = [depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs]
|
| 70 |
+
|
| 71 |
+
# Build one inpainting mask per input image by summing the decoded
# segmentation masks of a random subset of that image's annotations.
masks = []
for rgb_path in image_path:
    # The annotation file sits next to the image with a .json extension; read
    # it inside a context manager so the handle is closed right away.
    with open(rgb_path.replace('.jpg', '.json')) as anno_file:
        anno = json.load(anno_file)['annotations']
    random.shuffle(anno)
    object_num = random.randint(5, 10)
    # NOTE(review): the accumulator starts from anno[0] and the loop below adds
    # anno[0] again, so that object is counted twice. Kept as-is to preserve
    # the resulting mask values — confirm whether this is intended.
    mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
    # Slicing already clamps to the list length, so no explicit length check
    # is needed to take "up to object_num" annotations.
    for single_anno in anno[:object_num]:
        mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
    # NOTE(review): this multiplies the mask values by 3 and stacks a single
    # tensor (adding a leading axis of size 1); it does not make 3 copies.
    # Confirm that is what _rgbd_inpaint expects.
    mask = torch.stack([torch.tensor(mask) * 3], dim=0)
    masks.append(mask)
|
| 82 |
+
|
| 83 |
+
# mask = torch.zeros((512,512))
|
| 84 |
+
# mask[100:300, 200:400] = 1
|
| 85 |
+
# masks.append(mask)
|
| 86 |
+
|
| 87 |
+
# Bring the RGB images, depth maps and masks to a common 512x512 resolution
# using nearest-exact interpolation. Depth maps and masks get an extra leading
# axis before being resized.
resize_transform = Resize(
    size=[512, 512],
    interpolation=InterpolationMode.NEAREST_EXACT,
)
imgs = list(map(resize_transform, imgs))
depth_imgs = [resize_transform(single_depth.unsqueeze(0)) for single_depth in depth_imgs]
masks = [resize_transform(single_mask.unsqueeze(0)) for single_mask in masks]
|
| 91 |
+
|
| 92 |
+
# for gs in [1,2,3,4,5]:
|
| 93 |
+
# Run joint RGB-D inpainting for every (image, depth, mask, prompt) tuple and
# save each result as ./joint-<index>.jpg in the working directory.
for i, (rgb_i, depth_i, mask_i, prompt_i) in enumerate(zip(imgs, depth_imgs, masks, prompt)):
    output_image = model._rgbd_inpaint(
        rgb_i,
        depth_i.unsqueeze(0),
        mask_i,
        [prompt_i],
        processing_res=512,
        guidance_scale=3,
        mode='joint_inpaint',  # 'full_rgb_depth_inpaint', 'full_depth_rgb_inpaint', 'joint_inpaint'
    )
    output_image.save(f'./joint-{i}.jpg')
|