File size: 8,504 Bytes
57e54f6
 
 
 
 
 
 
 
412fd95
57e54f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2efd7f8
57e54f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9888744
57e54f6
2efd7f8
 
 
412fd95
2efd7f8
 
 
 
 
57e54f6
 
 
8ceafcb
57e54f6
 
412fd95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import torch
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
import os
import logging
import time
import random  # 补充导入

# Configure the root logger at INFO level (runs at import time) and create
# the module-level logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelManager:
    """Load, hold and invoke the captioning, CLIP and diffusion models.

    All models are loaded eagerly by ``__init__``.  Every ``load_*`` method
    is best-effort: a failure is logged (with traceback) and the matching
    attributes stay ``None``, so the ``generate_*`` / ``analyze_*`` entry
    points can retry the load lazily on their next call.
    """

    # Candidate style labels scored by ``analyze_style``; they are also the
    # keys of the dict it returns.
    # NOTE(review): these labels are Chinese while the configured CLIP
    # checkpoint is the original OpenAI one — confirm the scores are
    # meaningful for non-English prompts.
    STYLE_LABELS = ("商务正装", "休闲风", "运动风", "时尚潮流", "复古风", "街头风", "优雅风")

    # Every attribute that holds a loaded processor / model / pipeline.
    _MODEL_ATTRS = (
        "caption_processor",
        "caption_model",
        "clip_processor",
        "clip_model",
        "sd_pipeline",
        "controlnet",
        "controlnet_pipeline",
    )

    def __init__(self, cache_dir="/tmp/models"):
        """Pick the compute device and preload every model.

        Args:
            cache_dir: Directory passed as ``cache_dir`` to every
                ``from_pretrained`` call.  Defaults to the previously
                hard-coded ``/tmp/models``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cache_dir = cache_dir
        logger.info(f"使用设备: {self.device}")

        # Hub identifiers of the checkpoints to load.
        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
        }

        # Model containers; ``None`` means "not loaded".
        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Preload everything up front.
        self.load_all_models()

    @property
    def _torch_dtype(self):
        """Return float16 when running on CUDA, float32 otherwise."""
        return torch.float16 if self.device == "cuda" else torch.float32

    def load_all_models(self):
        """Load the caption, CLIP, Stable Diffusion and ControlNet models in order."""
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()

    def load_caption_model(self):
        """Load the BLIP processor and model; on failure log and leave them unset."""
        try:
            logger.info("加载 BLIP 图像描述模型...")
            processor = BlipProcessor.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir,
            )
            model = BlipForConditionalGeneration.from_pretrained(
                self.model_config["caption_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype,
            ).to(self.device)
            # Publish only after both parts loaded successfully.
            self.caption_processor = processor
            self.caption_model = model
            logger.info("BLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"BLIP 模型加载失败: {e}")

    def load_clip_model(self):
        """Load the CLIP processor and model; on failure log and leave them unset."""
        try:
            logger.info("加载 CLIP 模型...")
            processor = CLIPProcessor.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir,
            )
            model = CLIPModel.from_pretrained(
                self.model_config["clip_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype,
            ).to(self.device)
            # Publish only after both parts loaded successfully.
            self.clip_processor = processor
            self.clip_model = model
            logger.info("CLIP 模型加载完成")
        except Exception as e:
            logger.exception(f"CLIP 模型加载失败: {e}")

    def load_sd_pipeline(self):
        """Load the Stable Diffusion pipeline with an Euler-ancestral scheduler.

        The pipeline is built in a local variable and assigned to
        ``self.sd_pipeline`` only once it is fully configured, so a failure
        while moving it to the device or swapping the scheduler never leaves
        a half-initialised pipeline behind.
        """
        try:
            logger.info("加载 Stable Diffusion Pipeline...")
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_config["sd_model"],
                torch_dtype=self._torch_dtype,
                cache_dir=self.cache_dir,
                safety_checker=None,
                use_safetensors=True,
            )
            pipeline = pipeline.to(self.device)
            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
            self.sd_pipeline = pipeline
            logger.info("Stable Diffusion Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"Stable Diffusion Pipeline 加载失败: {e}")

    def load_controlnet_pipeline(self):
        """Load the ControlNet model and the ControlNet-conditioned SD pipeline.

        Both objects are assigned to ``self`` only after the whole sequence
        succeeded, so a failed pipeline load does not keep a ControlNet
        model referenced from ``self.controlnet``.
        """
        try:
            logger.info("加载 ControlNet 模型和 Pipeline...")
            controlnet = ControlNetModel.from_pretrained(
                self.model_config["controlnet_model"],
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype,
            ).to(self.device)

            pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                self.model_config["sd_model"],
                controlnet=controlnet,
                cache_dir=self.cache_dir,
                torch_dtype=self._torch_dtype,
                safety_checker=None,
            ).to(self.device)

            pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
            self.controlnet = controlnet
            self.controlnet_pipeline = pipeline
            logger.info("ControlNet Pipeline 加载完成")
        except Exception as e:
            logger.exception(f"ControlNet Pipeline 加载失败: {e}")

    # ------------------------------------------------------------------
    # Inference entry points
    # ------------------------------------------------------------------

    def generate_caption(self, image) -> str:
        """Generate a text caption for ``image`` with the BLIP model.

        Args:
            image: Image passed to the BLIP processor as ``images=``.

        Returns:
            The decoded caption with special tokens stripped.

        Raises:
            RuntimeError: If the BLIP model is still unavailable after a
                lazy load attempt.
        """
        if self.caption_model is None or self.caption_processor is None:
            self.load_caption_model()
            # The loader swallows its own errors, so re-check explicitly.
            if self.caption_model is None or self.caption_processor is None:
                message = "无法生成图像描述:BLIP 模型未加载"
                logger.error(message)
                raise RuntimeError(message)

        inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.caption_model.generate(**inputs, max_length=50)
        return self.caption_processor.decode(outputs[0], skip_special_tokens=True)

    def analyze_style(self, image) -> dict:
        """Score ``image`` against the fixed clothing-style labels with CLIP.

        Args:
            image: Image passed to the CLIP processor as ``images=``.

        Returns:
            Mapping of each label in ``STYLE_LABELS`` to its softmax
            probability (a ``float``).

        Raises:
            RuntimeError: If the CLIP model is still unavailable after a
                lazy load attempt.
        """
        if self.clip_model is None or self.clip_processor is None:
            self.load_clip_model()
            # The loader swallows its own errors, so re-check explicitly.
            if self.clip_model is None or self.clip_processor is None:
                message = "无法分析服装风格:CLIP 模型未加载"
                logger.error(message)
                raise RuntimeError(message)

        styles = list(self.STYLE_LABELS)

        inputs = self.clip_processor(
            text=styles,
            images=image,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.clip_model(**inputs)
            # One row per image; softmax across the candidate labels.
            probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]

        return {style: float(prob) for style, prob in zip(styles, probs)}

    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512):
        """Generate a design image from ``prompt`` with Stable Diffusion.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            negative_prompt: Optional negative prompt.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance scale forwarded to the pipeline.
            width: Output width in pixels.
            height: Output height in pixels.

        Returns:
            The first generated image, or a random solid-colour placeholder
            of ``width`` x ``height`` when the pipeline cannot be loaded.
        """
        if self.sd_pipeline is None:
            self.load_sd_pipeline()
            if self.sd_pipeline is None:
                logger.error("无法生成图像:Stable Diffusion 模型未加载")
                return self.create_placeholder_image(width, height)

        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width
        )
        return result.images[0]

    def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0):
        """Generate a try-on image with the ControlNet pipeline.

        Args:
            image: Conditioning image forwarded to the pipeline as ``image=``.
            prompt: Text prompt.
            reference_image: Optional reference design.  It is not passed to
                the pipeline; its presence only appends
                ``", based on reference design"`` to the prompt.
            negative_prompt: Optional negative prompt.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance scale forwarded to the pipeline.

        Returns:
            The first generated image, or a 512x768 random solid-colour
            placeholder when the pipeline cannot be loaded.
        """
        if self.controlnet_pipeline is None:
            self.load_controlnet_pipeline()
            if self.controlnet_pipeline is None:
                logger.error("无法生成3D试穿:ControlNet 模型未加载")
                return self.create_placeholder_image(512, 768)

        if reference_image is not None:
            prompt = f"{prompt}, based on reference design"

        result = self.controlnet_pipeline(
            prompt=prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        return result.images[0]

    def create_placeholder_image(self, width, height):
        """Return a ``width`` x ``height`` RGB image filled with one random colour.

        Each channel is drawn independently from the range 120..200.
        """
        color = (random.randint(120, 200), random.randint(120, 200), random.randint(120, 200))
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self) -> None:
        """Release every loaded model and clear the CUDA cache.

        The attributes are reset to ``None`` rather than deleted, so the
        manager stays usable afterwards: the next inference call reloads
        what it needs, and calling ``cleanup`` again is harmless.  Errors
        are logged, never raised.
        """
        import gc  # local import: only needed here

        logger.info("释放模型占用显存和缓存...")
        try:
            for attr in self._MODEL_ATTRS:
                setattr(self, attr, None)

            # Collect reference cycles first so the CUDA allocator can
            # actually hand the freed blocks back.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("显存和缓存清理完成")
        except Exception as e:
            logger.exception(f"清理显存失败: {e}")