# Source: newGPU/models/model_manager.py -- commit 6f9c5be ("Update models/model_manager.py"), uploaded by Humphreykowl.
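# Complete modal_manager.py (i.e. the full implementation of the earlier model_manager.py; it can be dropped in at model/modal_manager.py as a direct replacement).
# Per the original author, this covers three-view pattern-making consistency, sketch-style generation, multi-angle 3D try-on support, and GPU memory optimization.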
import torch
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
import os
import logging
import time
import random
import gc
# Configure the root logger at INFO level as a side effect of importing this module.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by ModelManager below.
logger = logging.getLogger(__name__)
class ModelManager:
    """Load and run the captioning, style-classification and image-generation models.

    On construction the manager picks a device (CUDA when available, else CPU),
    picks a matching dtype (float16 on CUDA, float32 on CPU) and tries to load
    four models: a BLIP captioner, a CLIP model, a Stable Diffusion pipeline and
    a ControlNet pipeline. A loading failure is logged, not raised, so the
    corresponding attributes may still be ``None`` afterwards; use
    :meth:`get_model_status` to check.
    """

    def __init__(self):
        """Select device/dtype, initialise all model slots and attempt to load them."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")
        # Hub identifiers of the checkpoints used by the load_* methods.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
        }
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None
        # Half precision on GPU, full precision on CPU.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.enable_attention_slicing = True
        # Kept for interface compatibility; not read anywhere in this class.
        self.enable_cpu_offload = False
        try:
            self.load_all_models()
        except Exception as e:
            # Construction is best-effort: keep the object usable for status
            # queries, but keep the traceback so the failure can be diagnosed.
            logger.warning(f"加载模型时出错: {e}", exc_info=True)

    def optimize_memory_usage(self):
        """Enable cuDNN autotuning and TF32 kernels when CUDA is available."""
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    def load_all_models(self):
        """Apply backend settings, then load every model in sequence.

        Any exception raised by a loader propagates to the caller.
        """
        self.optimize_memory_usage()
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()
        logger.info("所有模型加载完成")

    def load_caption_model(self):
        """Load the BLIP processor and captioning model onto ``self.device``."""
        self.caption_processor = BlipProcessor.from_pretrained(
            self.model_config["caption_model"], cache_dir="/tmp/models"
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.model_config["caption_model"],
            cache_dir="/tmp/models",
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True
        ).to(self.device)
        # NOTE: attention slicing is a diffusers pipeline feature. Calling
        # enable_attention_slicing() on this transformers model raised
        # AttributeError and aborted loading of all remaining models, so it
        # is intentionally not called here.
        self.caption_model.eval()

    def load_clip_model(self):
        """Load the CLIP processor and model onto ``self.device``."""
        self.clip_processor = CLIPProcessor.from_pretrained(
            self.model_config["clip_model"], cache_dir="/tmp/models"
        )
        self.clip_model = CLIPModel.from_pretrained(
            self.model_config["clip_model"],
            cache_dir="/tmp/models",
            torch_dtype=self.torch_dtype
        ).to(self.device)
        self.clip_model.eval()

    def _configure_pipeline(self, pipeline):
        """Apply the scheduler and memory options shared by both diffusion pipelines."""
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
        if self.enable_attention_slicing:
            pipeline.enable_attention_slicing()
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            # xformers is optional; continue with the pipeline's current attention.
            logger.debug("xformers memory-efficient attention not enabled: %s", exc)
        pipeline.enable_vae_slicing()

    def load_sd_pipeline(self):
        """Load the text-to-image Stable Diffusion pipeline with the safety checker disabled."""
        self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_config["sd_model"],
            torch_dtype=self.torch_dtype,
            cache_dir="/tmp/models",
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=True,
            low_cpu_mem_usage=True
        ).to(self.device)
        self._configure_pipeline(self.sd_pipeline)
        self.sd_pipeline.safety_checker = None

    def load_controlnet_pipeline(self):
        """Load the ControlNet model and the Stable Diffusion pipeline that wraps it."""
        self.controlnet = ControlNetModel.from_pretrained(
            self.model_config["controlnet_model"],
            cache_dir="/tmp/models",
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True
        ).to(self.device)
        self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.model_config["sd_model"],
            controlnet=self.controlnet,
            cache_dir="/tmp/models",
            torch_dtype=self.torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
            low_cpu_mem_usage=True
        ).to(self.device)
        self._configure_pipeline(self.controlnet_pipeline)

    @staticmethod
    def _prepare_image(image, max_side):
        """Return an RGB image whose sides are at most ``max_side`` pixels.

        ``thumbnail`` resizes in place, so a copy is shrunk instead of the
        object the caller passed in.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if image.width > max_side or image.height > max_side:
            image = image.copy()
            image.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
        return image

    def _to_model_inputs(self, inputs):
        """Move processor outputs to ``self.device`` and cast pixels to the model dtype.

        Only ``pixel_values`` is cast; token ids and masks keep their integer dtype.
        """
        inputs = inputs.to(self.device)
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(self.torch_dtype)
        return inputs

    @torch.no_grad()
    def generate_caption(self, image):
        """Generate a text caption for ``image`` with the BLIP model.

        Args:
            image: A PIL image. It is not modified; a converted/downscaled copy
                (at most 512 px per side) is fed to the model.

        Returns:
            The decoded caption string.
        """
        image = self._prepare_image(image, 512)
        inputs = self._to_model_inputs(self.caption_processor(images=image, return_tensors="pt"))
        outputs = self.caption_model.generate(
            **inputs, max_length=50, num_beams=4, temperature=0.7, do_sample=True
        )
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return caption

    @torch.no_grad()
    def analyze_style(self, image):
        """Score ``image`` against seven fixed clothing-style prompts with CLIP.

        Args:
            image: A PIL image. It is not modified; a converted/downscaled copy
                (at most 224 px per side) is fed to the model.

        Returns:
            A dict mapping each style name to its softmax probability as a float.
        """
        style_labels = [
            "business formal suit professional attire",
            "casual comfortable everyday wear",
            "athletic sportswear activewear",
            "fashion trendy modern stylish",
            "vintage retro classic style",
            "streetwear urban contemporary",
            "elegant sophisticated refined"
        ]
        # Display names, index-aligned with style_labels.
        style_names = ["商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风"]
        image = self._prepare_image(image, 224)
        inputs = self._to_model_inputs(
            self.clip_processor(
                text=style_labels,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77
            )
        )
        outputs = self.clip_model(**inputs)
        # Softmax in float32 so half-precision logits do not lose precision.
        probs = outputs.logits_per_image.float().softmax(dim=1).cpu().numpy()[0]
        return {name: float(prob) for name, prob in zip(style_names, probs)}

    @torch.no_grad()
    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512, seed=None):
        """Generate one image from ``prompt`` with the Stable Diffusion pipeline.

        Args:
            prompt: Text prompt.
            negative_prompt: Text to steer away from; a default is used when ``None``.
            num_inference_steps: Number of denoising steps passed to the pipeline.
            guidance_scale: Guidance scale passed to the pipeline.
            width: Requested width, rounded down to a multiple of 8.
            height: Requested height, rounded down to a multiple of 8.
            seed: Optional seed; when ``None`` no generator is passed.

        Returns:
            The first image of the pipeline result.
        """
        if negative_prompt is None:
            negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
        width = (width // 8) * 8
        height = (height // 8) * 8
        gen = None
        if seed is not None:
            gen = torch.Generator(device=self.device).manual_seed(int(seed))
        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=gen
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    @torch.no_grad()
    def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0, angle=0, width=512, height=768):
        """Generate one image conditioned on ``image`` through the ControlNet pipeline.

        Args:
            image: PIL control image; converted to RGB and resized to
                ``width`` x ``height``.
            prompt: Text prompt; a "view from N degrees" suffix is appended.
            reference_image: When not ``None``, only adds a phrase to the prompt;
                the image itself is not passed to the pipeline.
            negative_prompt: Text to steer away from; a default is used when ``None``.
            num_inference_steps: Number of denoising steps passed to the pipeline.
            guidance_scale: Guidance scale passed to the pipeline.
            angle: View angle in degrees, used in the prompt and in the seed.
            width: Output width, rounded down to a multiple of 8 (default 512).
            height: Output height, rounded down to a multiple of 8 (default 768).

        Returns:
            The first image of the pipeline result.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Honour the requested size; previously it was fixed at 512x768
        # regardless of the width/height arguments.
        width = (width // 8) * 8
        height = (height // 8) * 8
        control_image = image.resize((width, height), Image.Resampling.LANCZOS)
        if negative_prompt is None:
            negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"
        prompt_with_angle = f"{prompt}, view from {angle} degrees"
        if reference_image is not None:
            prompt_with_angle = f"{prompt_with_angle}, based on provided reference design"
        # Seed from wall-clock seconds plus the angle.
        gen = torch.Generator(device=self.device).manual_seed(int(time.time()) + int(angle))
        result = self.controlnet_pipeline(
            prompt=prompt_with_angle,
            image=control_image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=1.0,
            width=width,
            height=height,
            generator=gen
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    def create_placeholder_image(self, width, height):
        """Return a solid RGB image of the given size in one of four random pastel colours."""
        color = random.choice([(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)])
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self):
        """Run garbage collection and, when CUDA is available, release cached GPU memory."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            try:
                torch.cuda.ipc_collect()
            except Exception as exc:
                # Best-effort: a failed IPC collection must not break cleanup.
                logger.debug("torch.cuda.ipc_collect() failed: %s", exc)

    def get_model_status(self):
        """Report which models are loaded, the device, and GPU memory when CUDA is available.

        Returns:
            A dict with boolean flags ``caption_model``, ``clip_model``,
            ``sd_pipeline`` and ``controlnet_pipeline``, the ``device`` string,
            and, when CUDA is available, a ``gpu_memory`` dict holding the
            allocated and reserved sizes formatted in GB.
        """
        status = {
            "caption_model": self.caption_model is not None,
            "clip_model": self.clip_model is not None,
            "sd_pipeline": self.sd_pipeline is not None,
            "controlnet_pipeline": self.controlnet_pipeline is not None,
            "device": self.device
        }
        if torch.cuda.is_available():
            status["gpu_memory"] = {
                "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
                "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB"
            }
        return status