newGPU / models /model_manager.py
Humphreykowl's picture
Update models/model_manager.py
49dfa5d verified
raw
history blame
17.1 kB
# models/model_manager.py
import torch
from PIL import Image
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
CLIPProcessor,
CLIPModel
)
from diffusers import (
StableDiffusionPipeline,
StableDiffusionControlNetPipeline,
ControlNetModel,
EulerAncestralDiscreteScheduler
)
import numpy as np
import gc
import os
import logging
import time
from typing import Optional, Dict, List, Tuple
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class ModelManager:
    """Lazily load, run and unload the image models used by the app.

    Four model families are managed, each created on first use and cached on
    the instance:

    * BLIP image captioning        -> ``generate_caption``
    * CLIP zero-shot style ranking -> ``analyze_style``
    * Stable Diffusion txt2img     -> ``generate_image``
    * SD + OpenPose ControlNet     -> ``generate_controlnet_image``

    The public inference methods are best-effort: on any failure they log the
    error and return a fallback value (default caption, default styles,
    placeholder image, or plain SD output) instead of raising.
    """

    # Checkpoints retried once when the configured checkpoint fails to load.
    _CAPTION_FALLBACK = "Salesforce/blip-image-captioning-base"
    _SD_FALLBACK = "runwayml/stable-diffusion-v1-5"

    # Style labels returned to callers (dict keys, in Chinese) mapped to the
    # English prompts actually scored by CLIP. The OpenAI CLIP text encoder
    # was trained on English captions, so Chinese prompts rank poorly.
    _STYLE_PROMPTS = {
        "商务正装": "a photo of business formal clothing",
        "休闲风": "a photo of casual clothing",
        "运动风": "a photo of sporty athletic clothing",
        "时尚潮流": "a photo of trendy fashionable clothing",
        "复古风": "a photo of vintage retro clothing",
        "街头风": "a photo of streetwear clothing",
        "优雅风": "a photo of elegant clothing",
        "民族风": "a photo of traditional ethnic clothing",
    }

    def __init__(self):
        # Use the GPU when one is visible; everything also runs on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        # All models start unloaded and are created by the load_* methods.
        self.caption_model = None
        self.caption_processor = None
        self.clip_model = None
        self.clip_processor = None
        self.sd_pipeline = None
        self.controlnet_pipeline = None
        self.controlnet = None

        # Hugging Face Hub ids; small variants are chosen to fit a Space.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-base",
            "clip_model": "openai/clip-vit-base-patch32",
            "sd_model": "stabilityai/stable-diffusion-2-1-base",
            # OpenPose ControlNet; the lllyasviel/sd-controlnet-* checkpoints
            # were trained for SD 1.5 (768-dim text embeddings) and cannot run
            # on the SD 2.1 UNet/text encoder, hence a separate SD 1.5 base.
            "controlnet_model": "lllyasviel/sd-controlnet-openpose",
            # NOTE(review): confirm this repo id is still served by the Hub.
            "controlnet_base_model": "runwayml/stable-diffusion-v1-5",
        }

        # Download cache in the Space's temporary directory.
        self.cache_dir = "/tmp/models"
        os.makedirs(self.cache_dir, exist_ok=True)
        logger.info(f"模型缓存目录: {self.cache_dir}")

        # Seconds spent loading each model, and last-use timestamps
        # (time.time()), keyed by "caption" / "clip" / "sd" / "controlnet".
        self.load_times = {}
        self.last_used = {}

    def _torch_dtype(self) -> torch.dtype:
        """Return float16 on CUDA (halves weight memory) and float32 on CPU."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def _optimize_pipeline(self, pipeline, xformers_warning: str) -> None:
        """Enable memory-efficient attention on a diffusers pipeline (CUDA only).

        xFormers is preferred. Attention slicing is used only as the fallback
        because enabling it replaces the xFormers attention processor and
        slows inference.
        """
        if self.device != "cuda":
            return
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception:
            logger.warning(xformers_warning)
            pipeline.enable_attention_slicing()

    def load_caption_model(self):
        """Load the BLIP captioning processor and model if not yet loaded.

        On failure the configured checkpoint is replaced by
        ``_CAPTION_FALLBACK`` and loading is retried once. If the failing
        checkpoint already is the fallback, the error is logged and the model
        stays unloaded (``caption_model is None``).
        """
        if self.caption_model is not None:
            return
        start_time = time.time()
        logger.info("正在加载图像描述模型...")
        model_id = self.model_config["caption_model"]
        try:
            processor = BlipProcessor.from_pretrained(
                model_id,
                cache_dir=self.cache_dir
            )
            # torch_dtype already yields fp16 weights on CUDA.
            model = BlipForConditionalGeneration.from_pretrained(
                model_id,
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype()
            ).to(self.device)
        except Exception as e:
            logger.error(f"加载描述模型失败: {str(e)}")
            # Retry only with a *different* checkpoint to avoid endless recursion.
            if model_id != self._CAPTION_FALLBACK:
                logger.warning(f"回退到备用模型: {self._CAPTION_FALLBACK}")
                self.model_config["caption_model"] = self._CAPTION_FALLBACK
                self.load_caption_model()
            return
        # Publish both together so a half-loaded state is never visible.
        self.caption_processor = processor
        self.caption_model = model
        logger.info("图像描述模型加载完成")
        self.load_times["caption"] = time.time() - start_time
        self.last_used["caption"] = time.time()

    def load_clip_model(self):
        """Load the CLIP processor and model used for style analysis.

        Failures are logged and leave ``clip_model`` as ``None``.
        """
        if self.clip_model is not None:
            return
        start_time = time.time()
        logger.info("正在加载CLIP模型...")
        model_id = self.model_config["clip_model"]
        try:
            processor = CLIPProcessor.from_pretrained(
                model_id,
                cache_dir=self.cache_dir
            )
            model = CLIPModel.from_pretrained(
                model_id,
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype()
            ).to(self.device)
        except Exception as e:
            logger.error(f"加载CLIP模型失败: {str(e)}")
            return
        self.clip_processor = processor
        self.clip_model = model
        logger.info("CLIP模型加载完成")
        self.load_times["clip"] = time.time() - start_time
        self.last_used["clip"] = time.time()

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline if not yet loaded.

        GPUs with less than 10 GiB of memory are switched to the SD 1.5
        checkpoint. On a load failure the checkpoint is replaced by
        ``_SD_FALLBACK`` and loading is retried once; if the failing checkpoint
        already is the fallback, the pipeline stays unloaded.
        """
        if self.sd_pipeline is not None:
            return
        start_time = time.time()
        logger.info("正在加载Stable Diffusion模型...")
        if (self.device == "cuda"
                and torch.cuda.get_device_properties(0).total_memory < 10 * 1024**3):
            logger.info("检测到有限GPU内存,使用更小的SD模型")
            self.model_config["sd_model"] = self._SD_FALLBACK
        model_id = self.model_config["sd_model"]
        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                model_id,
                cache_dir=self.cache_dir,
                safety_checker=None,  # skipped to save memory
                torch_dtype=self._torch_dtype()
            ).to(self.device)
            self._optimize_pipeline(pipeline, "无法启用xformers,使用回退方案")
        except Exception as e:
            logger.error(f"加载SD模型失败: {str(e)}")
            # Retry only with a *different* checkpoint to avoid endless recursion.
            if model_id != self._SD_FALLBACK:
                logger.warning(f"回退到备用模型: {self._SD_FALLBACK}")
                self.model_config["sd_model"] = self._SD_FALLBACK
                self.load_sd_pipeline()
            return
        self.sd_pipeline = pipeline
        logger.info("Stable Diffusion模型加载完成")
        self.load_times["sd"] = time.time() - start_time
        self.last_used["sd"] = time.time()

    def load_controlnet_pipeline(self):
        """Load the OpenPose ControlNet and its SD 1.5 pipeline (3D try-on).

        The ControlNet and pipeline are stored only when both load, so a
        failure never leaves an orphaned ControlNet holding memory. Failures
        are logged and leave ``controlnet_pipeline`` as ``None``.
        """
        if self.controlnet_pipeline is not None:
            return
        start_time = time.time()
        logger.info("正在加载ControlNet模型...")
        dtype = self._torch_dtype()
        base_id = self.model_config.get("controlnet_base_model", self._SD_FALLBACK)
        try:
            controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.cache_dir,
                torch_dtype=dtype
            )
            # pipeline.to(device) also moves the attached ControlNet.
            pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                base_id,
                controlnet=controlnet,
                cache_dir=self.cache_dir,
                safety_checker=None,
                torch_dtype=dtype
            ).to(self.device)
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                pipeline.scheduler.config
            )
            self._optimize_pipeline(pipeline, "无法为ControlNet启用xformers")
        except Exception as e:
            logger.error(f"加载ControlNet模型失败: {str(e)}")
            return
        self.controlnet = controlnet
        self.controlnet_pipeline = pipeline
        logger.info("ControlNet模型加载完成")
        self.load_times["controlnet"] = time.time() - start_time
        self.last_used["controlnet"] = time.time()

    def generate_caption(self, image: Image.Image) -> str:
        """Return a BLIP caption for ``image`` (at most 50 tokens).

        Returns the default caption "时尚服装设计" if loading or inference fails.
        """
        try:
            self.load_caption_model()
            if self.caption_model is None:
                raise RuntimeError("caption model unavailable")
            self.last_used["caption"] = time.time()
            # Move pixel tensors to the model's device and dtype.
            inputs = self.caption_processor(
                images=image,
                return_tensors="pt"
            ).to(self.device, self._torch_dtype())
            output = self.caption_model.generate(**inputs, max_length=50)
            caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
            logger.info(f"生成的标题: {caption}")
            return caption
        except Exception as e:
            logger.error(f"生成标题失败: {str(e)}")
            return "时尚服装设计"

    def analyze_style(self, image: Image.Image) -> Dict[str, float]:
        """Rank ``image`` against the fixed style labels with zero-shot CLIP.

        Returns:
            The three most likely style labels mapped to their softmax
            probability, highest first. On failure, a fixed default mapping.
        """
        try:
            self.load_clip_model()
            if self.clip_model is None:
                raise RuntimeError("CLIP model unavailable")
            self.last_used["clip"] = time.time()

            style_labels = list(self._STYLE_PROMPTS)
            inputs = self.clip_processor(
                text=list(self._STYLE_PROMPTS.values()),
                images=image,
                return_tensors="pt",
                padding=True
            ).to(self.device)
            # The processor emits float32 pixels; cast them to the model dtype
            # (fp16 on CUDA) or the vision tower raises a dtype mismatch.
            inputs["pixel_values"] = inputs["pixel_values"].to(self.clip_model.dtype)

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                outputs = self.clip_model(**inputs)
            probs = outputs.logits_per_image.float().softmax(dim=1).cpu().numpy()[0]

            top3_idx = np.argsort(probs)[-3:][::-1]
            top_styles = {style_labels[i]: float(probs[i]) for i in top3_idx}
            logger.info(f"风格分析结果: {top_styles}")
            return top_styles
        except Exception as e:
            logger.error(f"风格分析失败: {str(e)}")
            return {"休闲风": 0.8, "时尚潮流": 0.7}

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512
    ) -> Image.Image:
        """Generate a design image from ``prompt`` with Stable Diffusion.

        On failure returns a light-grey placeholder of the requested size
        (512x512 if the requested size is not positive).
        """
        try:
            self.load_sd_pipeline()
            if self.sd_pipeline is None:
                raise RuntimeError("Stable Diffusion pipeline unavailable")
            self.last_used["sd"] = time.time()
            # No torch.autocast: the pipeline already runs in its load dtype,
            # diffusers discourages autocast, and on CPU it would force bf16.
            result = self.sd_pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width
            ).images[0]
            logger.info(f"成功生成设计图像: {prompt[:50]}...")
            return result
        except Exception as e:
            logger.error(f"生成设计图像失败: {str(e)}")
            size = (width, height) if width > 0 and height > 0 else (512, 512)
            return Image.new('RGB', size, color=(220, 220, 220))

    def generate_controlnet_image(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 35,
        guidance_scale: float = 8.0,
        controlnet_conditioning_scale: float = 0.8
    ) -> Image.Image:
        """Generate a 3D try-on image conditioned on ``image`` via ControlNet.

        ``image`` is passed to the pipeline as the control image (presumably
        an OpenPose skeleton, given the configured ControlNet — TODO confirm
        with callers). On failure falls back to plain ``generate_image`` with
        the same prompt, steps and guidance scale.
        """
        try:
            self.load_controlnet_pipeline()
            if self.controlnet_pipeline is None:
                raise RuntimeError("ControlNet pipeline unavailable")
            self.last_used["controlnet"] = time.time()
            result = self.controlnet_pipeline(
                prompt=prompt,
                image=image,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale
            ).images[0]
            logger.info("成功生成3D试穿图像")
            return result
        except Exception as e:
            logger.error(f"生成3D试穿图像失败: {str(e)}")
            return self.generate_image(
                prompt,
                negative_prompt,
                num_inference_steps,
                guidance_scale
            )

    def unload_model(self, model_type: str):
        """Drop the references to one model family and reclaim memory.

        Args:
            model_type: "caption", "clip", "sd" or "controlnet". Unknown
                values only log and run the memory cleanup.
        """
        logger.info(f"卸载模型: {model_type}")
        # Rebinding to None releases this manager's reference to each model.
        if model_type == "caption" and self.caption_model is not None:
            self.caption_model = None
            self.caption_processor = None
            logger.info("卸载图像描述模型")
        elif model_type == "clip" and self.clip_model is not None:
            self.clip_model = None
            self.clip_processor = None
            logger.info("卸载CLIP模型")
        elif model_type == "sd" and self.sd_pipeline is not None:
            self.sd_pipeline = None
            logger.info("卸载Stable Diffusion模型")
        elif model_type == "controlnet" and self.controlnet_pipeline is not None:
            self.controlnet_pipeline = None
            self.controlnet = None
            logger.info("卸载ControlNet模型")
        self.cleanup_memory()

    def cleanup(self):
        """Drop every loaded model and processor, then reclaim memory."""
        logger.info("清理所有模型释放内存...")
        self.caption_model = None
        self.caption_processor = None
        self.clip_model = None
        self.clip_processor = None
        self.sd_pipeline = None
        self.controlnet_pipeline = None
        self.controlnet = None
        self.cleanup_memory()
        logger.info("内存清理完成")

    def cleanup_memory(self):
        """Run garbage collection and release cached CUDA blocks.

        ``gc.collect()`` runs first so tensors freed by dropped references are
        returned to the CUDA caching allocator before ``empty_cache()``.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_memory_usage(self) -> Dict[str, float]:
        """Return GPU memory statistics for device 0 in GiB.

        Keys are ``gpu_total``, ``gpu_used`` (memory allocated by PyTorch
        tensors) and ``gpu_free`` (total minus used). The dict is empty
        without CUDA.
        """
        mem_info = {}
        if torch.cuda.is_available():
            gib = 1024**3
            mem_info["gpu_total"] = torch.cuda.get_device_properties(0).total_memory / gib
            mem_info["gpu_used"] = torch.cuda.memory_allocated() / gib
            mem_info["gpu_free"] = mem_info["gpu_total"] - mem_info["gpu_used"]
        return mem_info

    def get_model_status(self) -> Dict[str, str]:
        """Return load state, load time and time-since-last-use per model."""
        status = {
            "caption_model": "已加载" if self.caption_model is not None else "未加载",
            "clip_model": "已加载" if self.clip_model is not None else "未加载",
            "sd_model": "已加载" if self.sd_pipeline is not None else "未加载",
            "controlnet_model": "已加载" if self.controlnet_pipeline is not None else "未加载"
        }
        now = time.time()
        for model in ("caption", "clip", "sd", "controlnet"):
            if model in self.load_times:
                status[f"{model}_load_time"] = f"{self.load_times[model]:.2f}秒"
            if model in self.last_used:
                mins_ago = (now - self.last_used[model]) / 60
                status[f"{model}_last_used"] = f"{mins_ago:.1f}分钟前"
        return status

    def __del__(self):
        """Best-effort release of model references on garbage collection.

        Errors are suppressed: at interpreter shutdown, or when ``__init__``
        failed part-way, names and attributes used by ``cleanup`` may be gone.
        """
        try:
            self.cleanup()
        except Exception:
            pass