# Source: newGPU/models/model_manager.py (author: Humphreykowl)
# Last change: commit 2efd7f8 "Update models/model_manager.py" (7.53 kB)
# models/model_manager.py
import gc
import logging
import os
import time

import numpy as np
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
# Module-scoped logger used by ModelManager for progress and error messages.
logger = logging.getLogger(__name__)
# NOTE: configures the root logger at import time (INFO level) so the
# loading progress messages below are visible without extra setup.
logging.basicConfig(level=logging.INFO)
class ModelManager:
    """Own and serve the BLIP captioning, CLIP, Stable Diffusion and ControlNet models.

    ``__init__`` preloads every model through ``load_all_models``. Each
    ``load_*`` method logs (with traceback) and swallows its own failure, so a
    model that failed to load leaves its attributes as ``None``. The inference
    methods retry the load once and raise ``RuntimeError`` if the model is
    still unavailable, instead of failing later on a ``None`` attribute.
    """

    def __init__(self, cache_dir="/tmp/models"):
        """Select the device, record the Hub model ids and preload all models.

        Args:
            cache_dir: Directory passed as ``cache_dir`` to every
                ``from_pretrained`` call (defaults to the original ``/tmp/models``).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")
        self.cache_dir = cache_dir
        # Hugging Face Hub ids of the models loaded below.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-base",
            "clip_model": "openai/clip-vit-base-patch32",
            # The original runwayml/stable-diffusion-v1-5 repo is no longer
            # available on the Hub; this is the maintained mirror of SD 1.5.
            "sd_model": "stable-diffusion-v1-5/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/sd-controlnet-openpose",
        }
        # Model slots; None means "not loaded (or the last load failed)".
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None
        # Preload every model up front.
        self.load_all_models()

    @property
    def _dtype(self):
        """Weight dtype for every model: float16 on CUDA, float32 otherwise."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def _pretrained_variant_kwargs(self):
        """Extra diffusers ``from_pretrained`` kwargs selecting fp16 weight files on CUDA.

        ``variant="fp16"`` replaces the deprecated ``revision="fp16"`` branch.
        """
        return {"variant": "fp16"} if self.device == "cuda" else {}

    @staticmethod
    def _require(name, *components):
        """Raise ``RuntimeError`` if any of ``components`` is ``None``.

        Args:
            name: Human-readable model name used in the error message.
            *components: Loaded objects that must all be present.
        """
        if any(component is None for component in components):
            raise RuntimeError(
                f"{name} is not available: loading it failed (see the error logged above)"
            )

    def load_all_models(self):
        """Load every model; a failure in one loader does not stop the others."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        # Runs after load_sd_pipeline so the ControlNet pipeline can reuse its components.
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP captioning processor and model; log and continue on failure."""
        try:
            logger.info("加载 BLIP 图像描述模型...")
            self.caption_processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir,
            )
            self.caption_model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._dtype,
            ).to(self.device)
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor and model; log and continue on failure."""
        try:
            logger.info("加载 CLIP 模型...")
            self.clip_processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir,
            )
            self.clip_model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._dtype,
            ).to(self.device)
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion text-to-image pipeline with an Euler-Ancestral scheduler.

        ``self.sd_pipeline`` is only assigned once the pipeline is fully set up
        and on the target device. On failure the error is logged and the slot
        keeps its previous value.
        """
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self._dtype,
                cache_dir=self.cache_dir,
                safety_checker=None,  # safety checker deliberately disabled; configure if needed
                **self._pretrained_variant_kwargs(),
            )
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
            self.sd_pipeline = pipeline.to(self.device)
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the OpenPose ControlNet and build a ControlNet SD pipeline around it.

        If ``self.sd_pipeline`` is already loaded, its components (VAE, text
        encoder, tokenizer, UNet, ...) are reused rather than loading a second
        copy of the SD weights. Otherwise the base model is loaded from the Hub.
        """
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._dtype,
            ).to(self.device)
            if self.sd_pipeline is not None:
                pipeline = StableDiffusionControlNetPipeline(
                    **self.sd_pipeline.components,
                    controlnet=controlnet,
                )
            else:
                pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                    self.model_config["sd_model"],
                    controlnet=controlnet,
                    cache_dir=self.cache_dir,
                    torch_dtype=self._dtype,
                    safety_checker=None,
                    **self._pretrained_variant_kwargs(),
                )
            # A fresh scheduler object, so this pipeline never shares scheduler
            # state with sd_pipeline even when the other components are shared.
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
            self.controlnet = controlnet
            self.controlnet_pipeline = pipeline.to(self.device)
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"ControlNet Pipeline 加载失败: {e}")

    # ---- Inference entry points -------------------------------------------

    def generate_caption(self, image):
        """Return a BLIP-generated caption (at most 50 tokens) for ``image``.

        Raises:
            RuntimeError: if the caption model could not be loaded.
        """
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()
        self._require("BLIP caption model", self.caption_processor, self.caption_model)
        inputs = self.caption_processor(images=image, return_tensors="pt")
        # Cast pixels to the model's weight dtype (float16 on CUDA); moving them
        # to the device alone leaves them float32 and the forward pass fails.
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.caption_model.dtype)
        with torch.no_grad():
            outputs = self.caption_model.generate(pixel_values=pixel_values, max_length=50)
        return self.caption_processor.decode(outputs[0], skip_special_tokens=True)

    def analyze_style(self, image):
        """Return ``{"clip_feature_vector": v}``, the L2-normalised CLIP image embedding.

        The first image's embedding is upcast to float32 before normalisation,
        so ``v`` is a float32 NumPy array.

        Raises:
            RuntimeError: if the CLIP model could not be loaded.
        """
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()
        self._require("CLIP model", self.clip_processor, self.clip_model)
        inputs = self.clip_processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.clip_model.dtype)
        with torch.no_grad():
            features = self.clip_model.get_image_features(pixel_values=pixel_values)
        vector = features[0].float().cpu().numpy()
        # Simple L2 normalisation; the epsilon guards against a zero vector.
        normalized = vector / (np.linalg.norm(vector) + 1e-10)
        return {"clip_feature_vector": normalized}

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
        """Run text-to-image Stable Diffusion and return ``result.images[0]``.

        Raises:
            RuntimeError: if the SD pipeline could not be loaded.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()
        self._require("Stable Diffusion pipeline", self.sd_pipeline)
        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
        """Run the ControlNet pipeline conditioned on ``image`` and return ``result.images[0]``.

        Args:
            image: The control image, expected as a PIL image (e.g. a human pose map).

        Raises:
            RuntimeError: if the ControlNet pipeline could not be loaded.
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()
        self._require("ControlNet pipeline", self.controlnet_pipeline)
        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def cleanup(self):
        """Drop every model reference and release cached GPU memory.

        Slots are reset to ``None`` rather than deleted. This keeps the lazy-load
        checks in the inference methods working, and makes repeated calls safe.
        """
        logger.info("释放模型占用显存和缓存...")
        try:
            self.caption_model = None
            self.caption_processor = None
            self.clip_model = None
            self.clip_processor = None
            self.sd_pipeline = None
            self.controlnet = None
            self.controlnet_pipeline = None
            # Collect unreachable objects first, so their CUDA blocks are back in
            # the caching allocator before empty_cache() releases them.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("显存清理完成")
        except Exception as e:
            logger.exception(f"清理显存失败: {e}")