# File size: 8,504 Bytes
import torch
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
import os
import logging
import time
import random # 补充导入
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
    """Own the BLIP, CLIP, Stable Diffusion and ControlNet models used by the app.

    All models are loaded eagerly in ``__init__``. Each loader logs and swallows
    its own failure, so a handle may still be ``None`` afterwards; the
    ``generate_*`` / ``analyze_*`` methods retry the load lazily before use.
    """

    # Directory passed as ``cache_dir`` to every ``from_pretrained`` call.
    CACHE_DIR = "/tmp/models"

    # Candidate labels scored by ``analyze_style``; they are also the keys of
    # the dictionary it returns, so they must not be altered.
    STYLE_LABELS = ("商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风")

    def __init__(self):
        """Select the compute device and preload every model."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        # Checkpoint identifiers handed to the ``from_pretrained`` loaders.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose",
        }

        # Model handles; each stays None until its loader succeeds.
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Preload all models up front.
        self.load_all_models()

    @property
    def _torch_dtype(self):
        """Weight dtype used when loading: float16 on CUDA, float32 otherwise."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def _prepare_inputs(self, batch, model):
        """Return ``batch`` as a dict of tensors ready to feed into ``model``.

        Every tensor is moved to ``self.device``. Floating-point tensors are
        additionally cast to ``model.dtype`` (when the model exposes one) so
        that float32 processor output does not clash with float16 weights;
        integer tensors such as token ids keep their dtype. Non-tensor values
        are passed through unchanged.
        """
        target_dtype = getattr(model, "dtype", None)
        prepared = {}
        for key, value in batch.items():
            if torch.is_tensor(value):
                if target_dtype is not None and value.is_floating_point():
                    value = value.to(device=self.device, dtype=target_dtype)
                else:
                    value = value.to(self.device)
            prepared[key] = value
        return prepared

    def load_all_models(self) -> None:
        """Run every loader in turn; each one handles its own failure."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self) -> None:
        """Load the BLIP processor and captioning model.

        On failure the error is logged and the existing handles are left as
        they were.
        """
        try:
            logger.info("加载 BLIP 图像描述模型...")
            name = self.model_config["caption_model"]
            processor = BlipProcessor.from_pretrained(name, cache_dir=self.CACHE_DIR)
            model = BlipForConditionalGeneration.from_pretrained(
                name,
                cache_dir=self.CACHE_DIR,
                torch_dtype=self._torch_dtype,
            ).to(self.device)
            # Publish only once both parts loaded, so the pair is never half-set.
            self.caption_processor = processor
            self.caption_model = model
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self) -> None:
        """Load the CLIP processor and model.

        On failure the error is logged and the existing handles are left as
        they were.
        """
        try:
            logger.info("加载 CLIP 模型...")
            name = self.model_config["clip_model"]
            processor = CLIPProcessor.from_pretrained(name, cache_dir=self.CACHE_DIR)
            model = CLIPModel.from_pretrained(
                name,
                cache_dir=self.CACHE_DIR,
                torch_dtype=self._torch_dtype,
            ).to(self.device)
            self.clip_processor = processor
            self.clip_model = model
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self) -> None:
        """Load the Stable Diffusion pipeline with an Euler-ancestral scheduler.

        The pipeline is created with ``safety_checker=None``. On failure the
        error is logged and ``self.sd_pipeline`` is left as it was.
        """
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self._torch_dtype,
                cache_dir=self.CACHE_DIR,
                safety_checker=None,
                use_safetensors=True,
            )
            pipeline = pipeline.to(self.device)
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            # Assign last: a failure above must not leave a half-configured
            # pipeline that later looks "loaded" to generate_image.
            self.sd_pipeline = pipeline
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self) -> None:
        """Load the ControlNet model and the ControlNet-conditioned SD pipeline.

        The pipeline is created with ``safety_checker=None`` and an
        Euler-ancestral scheduler. On failure the error is logged and the
        existing handles are left as they were.
        """
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.CACHE_DIR,
                torch_dtype=self._torch_dtype,
            ).to(self.device)
            pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=controlnet,
                cache_dir=self.CACHE_DIR,
                torch_dtype=self._torch_dtype,
                safety_checker=None,
            ).to(self.device)
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            self.controlnet = controlnet
            self.controlnet_pipeline = pipeline
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"ControlNet Pipeline 加载失败: {e}")

    # ------------------------------------------------------------------
    # Inference entry points
    # ------------------------------------------------------------------
    def generate_caption(self, image) -> str:
        """Generate a text caption for ``image`` with the BLIP model.

        Args:
            image: Image accepted by the BLIP processor's ``images`` argument.

        Returns:
            The decoded caption with special tokens removed.

        Raises:
            RuntimeError: If the BLIP model is still unavailable after a lazy
                load attempt.
        """
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()
        # The loader swallows its errors, so re-check instead of calling None.
        if self.caption_model is None or self.caption_processor is None:
            raise RuntimeError("BLIP caption model is not loaded")

        batch = self.caption_processor(images=image, return_tensors="pt")
        inputs = self._prepare_inputs(batch, self.caption_model)
        with torch.no_grad():
            outputs = self.caption_model.generate(**inputs, max_length=50)
        return self.caption_processor.decode(outputs[0], skip_special_tokens=True)

    def analyze_style(self, image) -> "dict":
        """Score ``image`` against the fixed clothing-style labels with CLIP.

        Args:
            image: Image accepted by the CLIP processor's ``images`` argument.

        Returns:
            Mapping of each label in ``STYLE_LABELS`` to its softmax
            probability (a Python float); the values sum to 1.

        Raises:
            RuntimeError: If the CLIP model is still unavailable after a lazy
                load attempt.
        """
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()
        if self.clip_model is None or self.clip_processor is None:
            raise RuntimeError("CLIP model is not loaded")

        styles = list(self.STYLE_LABELS)
        batch = self.clip_processor(
            text=styles,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        inputs = self._prepare_inputs(batch, self.clip_model)
        with torch.no_grad():
            outputs = self.clip_model(**inputs)

        # Upcast before the softmax so half-precision logits do not lose
        # precision (and so the op is available on every CPU build).
        probs = outputs.logits_per_image.float().softmax(dim=1).cpu().numpy()[0]
        return {style: float(prob) for style, prob in zip(styles, probs)}

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25,
                       guidance_scale=7.5, width=512, height=512):
        """Generate a design image from ``prompt`` with Stable Diffusion.

        Args:
            prompt: Text prompt passed to the pipeline.
            negative_prompt: Optional negative prompt passed to the pipeline.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance scale passed to the pipeline.
            width: Output width in pixels.
            height: Output height in pixels.

        Returns:
            The first generated image, or a solid-colour placeholder of the
            requested size when the pipeline could not be loaded.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()
        if self.sd_pipeline is None:
            logger.error("无法生成图像:Stable Diffusion 模型未加载")
            return self.create_placeholder_image(width, height)

        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, reference_image=None,
                                  negative_prompt=None, num_inference_steps=30,
                                  guidance_scale=8.0):
        """Generate a try-on image with the ControlNet pipeline.

        Args:
            image: Conditioning image forwarded as the pipeline's ``image``
                argument (presumably a pose map, given the OpenPose
                checkpoint; confirm with the caller).
            prompt: Text prompt passed to the pipeline.
            reference_image: Not fed to the model; when it is not ``None`` the
                prompt is suffixed with ", based on reference design".
            negative_prompt: Optional negative prompt passed to the pipeline.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance scale passed to the pipeline.

        Returns:
            The first generated image, or a 512x768 solid-colour placeholder
            when the pipeline could not be loaded.
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()
        if self.controlnet_pipeline is None:
            logger.error("无法生成3D试穿:ControlNet 模型未加载")
            return self.create_placeholder_image(512, 768)

        if reference_image is not None:
            prompt = f"{prompt}, based on reference design"

        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def create_placeholder_image(self, width, height):
        """Return a ``width`` x ``height`` RGB image filled with one random colour.

        Each channel is drawn independently from the range 120..200 inclusive.
        """
        color = (
            random.randint(120, 200),
            random.randint(120, 200),
            random.randint(120, 200),
        )
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self) -> None:
        """Drop the model handles and release cached CUDA memory.

        The handles are reset to ``None`` rather than deleted, so the manager
        stays usable: the next ``generate_*`` / ``analyze_*`` call reloads
        what it needs, and calling ``cleanup`` twice is harmless. The
        processors are kept. Errors are logged, never raised.
        """
        import gc  # local import: only needed here

        logger.info("释放模型占用显存和缓存...")
        try:
            self.caption_model = None
            self.clip_model = None
            self.sd_pipeline = None
            self.controlnet = None
            self.controlnet_pipeline = None
            # Collect first so that tensors kept alive only by reference
            # cycles are freed before the CUDA cache is emptied.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("显存和缓存清理完成")
        except Exception as e:
            logger.error(f"清理显存失败: {e}")
|