# newGPU/models/model_manager.py
# Author: Humphreykowl — commit 412fd95 (verified), "Update models/model_manager.py", 8.5 kB
import torch
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
import os
import logging
import time
import random # 补充导入
# Module-wide logging: INFO level, logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelManager:
    """Owns every ML model used by the app and exposes thin inference wrappers.

    Four models are managed: BLIP (image captioning), CLIP (clothing-style
    scoring), a Stable Diffusion text-to-image pipeline, and a ControlNet
    (openpose) try-on pipeline. All are loaded eagerly in ``__init__`` onto
    ``self.device``; a slot left at ``None`` means "not loaded", and each
    ``generate_*``/``analyze_*`` method lazily retries the load if a previous
    attempt failed or :meth:`cleanup` was called.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")
        # Half precision on GPU halves VRAM; CPU inference requires fp32.
        # Computed once here instead of being repeated in every loader.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        # Hub ids for every model this manager owns.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
        }
        # Model slots — None means "not loaded (yet)".
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None
        # Pre-load all models up front.
        self.load_all_models()

    def load_all_models(self):
        """Eagerly load every model; individual failures are logged, not raised."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP captioning processor + model (slots stay None on failure)."""
        try:
            logger.info("加载 BLIP 图像描述模型...")
            self.caption_processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir="/tmp/models"
            )
            self.caption_model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir="/tmp/models",
                torch_dtype=self.torch_dtype
            ).to(self.device)
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            logger.error(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor + model (slots stay None on failure)."""
        try:
            logger.info("加载 CLIP 模型...")
            self.clip_processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir="/tmp/models"
            )
            self.clip_model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir="/tmp/models",
                torch_dtype=self.torch_dtype
            ).to(self.device)
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.error(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline (slot stays None on failure)."""
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self.torch_dtype,
                cache_dir="/tmp/models",
                safety_checker=None,
                use_safetensors=True
            )
            self.sd_pipeline = self.sd_pipeline.to(self.device)
            # Euler-Ancestral tends to give good results at low step counts.
            self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.sd_pipeline.scheduler.config
            )
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.error(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the openpose ControlNet and its SD pipeline (slots stay None on failure)."""
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            self.controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir="/tmp/models",
                torch_dtype=self.torch_dtype
            ).to(self.device)
            self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=self.controlnet,
                cache_dir="/tmp/models",
                torch_dtype=self.torch_dtype,
                safety_checker=None
            ).to(self.device)
            self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.controlnet_pipeline.scheduler.config
            )
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.error(f"ControlNet Pipeline 加载失败: {e}")

    # ------------------------------------------------------------------
    # Inference wrappers
    # ------------------------------------------------------------------

    def generate_caption(self, image):
        """Generate an image caption with BLIP.

        Args:
            image: PIL image (or anything ``BlipProcessor`` accepts).

        Returns:
            str: the decoded caption.

        Raises:
            RuntimeError: if the BLIP model could not be loaded.
        """
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()
        if self.caption_model is None or self.caption_processor is None:
            # The retry failed too (errors are only logged by the loader) —
            # fail loudly instead of with an opaque TypeError on None below.
            raise RuntimeError("BLIP caption model is not available")
        inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
        if self.device == "cuda":
            # The model weights are fp16 on GPU; cast the pixels to match,
            # otherwise the forward pass raises a dtype-mismatch error.
            inputs["pixel_values"] = inputs["pixel_values"].to(self.torch_dtype)
        with torch.no_grad():
            outputs = self.caption_model.generate(**inputs, max_length=50)
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        return caption

    def analyze_style(self, image):
        """Score the clothing style of *image* against a fixed label set with CLIP.

        Args:
            image: PIL image (or anything ``CLIPProcessor`` accepts).

        Returns:
            dict[str, float]: softmax probability per style label.

        Raises:
            RuntimeError: if the CLIP model could not be loaded.
        """
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()
        if self.clip_model is None or self.clip_processor is None:
            raise RuntimeError("CLIP model is not available")
        styles = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
        inputs = self.clip_processor(
            text=styles,
            images=image,
            return_tensors="pt",
            padding=True
        ).to(self.device)
        if self.device == "cuda":
            # Only the image tensor is floating point; token ids stay integer.
            inputs["pixel_values"] = inputs["pixel_values"].to(self.torch_dtype)
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]
        style_scores = {style: float(prob) for style, prob in zip(styles, probs)}
        return style_scores

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
        """Generate a design image with Stable Diffusion.

        Falls back to a random solid-color placeholder when the pipeline
        cannot be loaded, so callers always receive a PIL image.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()
        if self.sd_pipeline is None:
            logger.error("无法生成图像:Stable Diffusion 模型未加载")
            return self.create_placeholder_image(width, height)
        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
        """Generate a pose-conditioned try-on image with ControlNet.

        Args:
            image: conditioning image for ControlNet (openpose map expected
                by the chosen checkpoint — not verified here).
            prompt: text prompt; a reference-design hint is appended when
                ``reference_image`` is given.

        Falls back to a 512x768 placeholder when the pipeline is unavailable.
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()
        if self.controlnet_pipeline is None:
            logger.error("无法生成3D试穿:ControlNet 模型未加载")
            return self.create_placeholder_image(512, 768)
        if reference_image is not None:
            # The reference image itself is not fed to the pipeline; it only
            # strengthens the prompt.
            prompt = f"{prompt}, based on reference design"
        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def create_placeholder_image(self, width, height):
        """Return a ``width`` x ``height`` RGB image in a random muted color."""
        color = (random.randint(120, 200), random.randint(120, 200), random.randint(120, 200))
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self):
        """Release model memory.

        Resets every model/processor slot back to ``None`` — NOT ``del`` —
        so the ``is None`` lazy-reload checks in the inference wrappers keep
        working afterwards and the method is safe to call more than once.
        (The original ``del`` left the attributes missing entirely, making
        any later call raise AttributeError.) Also empties the CUDA cache.
        """
        logger.info("释放模型占用显存和缓存...")
        try:
            self.caption_processor = None
            self.caption_model = None
            self.clip_processor = None
            self.clip_model = None
            self.sd_pipeline = None
            self.controlnet = None
            self.controlnet_pipeline = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("显存和缓存清理完成")
        except Exception as e:
            logger.error(f"清理显存失败: {e}")