File size: 9,869 Bytes
6f9c5be
 
 
57e54f6
 
 
 
 
 
 
 
422bb60
 
57e54f6
 
 
 
 
 
 
 
 
 
 
6f9c5be
57e54f6
 
 
 
 
 
 
 
 
 
 
6f9c5be
422bb60
 
6f9c5be
 
 
 
 
 
57e54f6
422bb60
 
 
 
 
 
57e54f6
422bb60
6f9c5be
 
 
 
 
57e54f6
 
6f9c5be
 
 
 
 
 
 
 
 
57e54f6
 
6f9c5be
 
 
57e54f6
 
6f9c5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57e54f6
 
6f9c5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57e54f6
6f9c5be
57e54f6
6f9c5be
 
 
 
 
 
 
 
 
 
 
57e54f6
422bb60
57e54f6
6f9c5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f326b4
422bb60
6f9c5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc48ff
 
6f9c5be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc48ff
 
6f9c5be
9cc48ff
 
 
6f9c5be
 
 
 
9cc48ff
6f9c5be
 
9cc48ff
 
 
 
 
 
 
 
 
 
 
 
6f9c5be
9cc48ff
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
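# Complete modal_manager.py (i.e. the full implementation of the former model_manager.py; it can be dropped in directly at model/modal_manager.py)
# Covers all features: three-view pattern-making consistency, sketch-style generation, multi-angle 3D try-on support, VRAM optimisation, etc.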

import torch
from PIL import Image
import numpy as np
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, EulerAncestralDiscreteScheduler
import os
import logging
import time
import random
import gc

# Configure root logging at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by ModelManager below.
logger = logging.getLogger(__name__)

class ModelManager:
    """Load and serve the captioning, style-scoring and image-generation models.

    The manager owns four models, configured in ``self.model_config``:

    * a BLIP captioning model (``generate_caption``),
    * a CLIP model used to score an image against fixed style prompts
      (``analyze_style``),
    * a Stable Diffusion text-to-image pipeline (``generate_image``),
    * a Stable Diffusion + OpenPose ControlNet pipeline
      (``generate_controlnet_image``).

    All weights are cached under ``/tmp/models``. On CUDA the models run in
    float16, otherwise in float32.
    """

    # (CLIP text prompt, display name) pairs scored by ``analyze_style``.
    _STYLE_PROMPTS = (
        ("business formal suit professional attire", "商务正装"),
        ("casual comfortable everyday wear", "休闲风"),
        ("athletic sportswear activewear", "运动风"),
        ("fashion trendy modern stylish", "时尚潮流"),
        ("vintage retro classic style", "复古风"),
        ("streetwear urban contemporary", "街头风"),
        ("elegant sophisticated refined", "优雅风"),
    )

    def __init__(self):
        """Select device/dtype and eagerly try to load every model.

        Loading is best-effort: any exception raised by ``load_all_models`` is
        logged (with traceback) and swallowed, leaving the attributes of the
        models that did not load set to ``None``. Use ``get_model_status`` to
        see what is available.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        self.model_config = {
            "caption_model": "Salesforce/blip-image-captioning-large",
            "clip_model": "openai/clip-vit-large-patch14",
            "sd_model": "runwayml/stable-diffusion-v1-5",
            "controlnet_model": "lllyasviel/control_v11p_sd15_openpose"
        }

        self.caption_processor = None
        self.caption_model = None
        self.clip_processor = None
        self.clip_model = None
        self.sd_pipeline = None
        self.controlnet = None
        self.controlnet_pipeline = None

        # Half precision only on GPU; CPU inference stays in float32.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.enable_attention_slicing = True
        self.enable_cpu_offload = False

        try:
            self.load_all_models()
        except Exception as e:
            # Keep the manager constructible even when weights are missing,
            # but record the traceback so the failure can be diagnosed.
            logger.warning(f"加载模型时出错: {e}", exc_info=True)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _prepare_inputs(self, batch):
        """Return ``batch`` as a plain dict placed on the model device.

        Every tensor value is moved to ``self.device``; floating-point tensors
        are additionally cast to ``self.torch_dtype`` so that float32 pixel
        values match half-precision model weights. Integer tensors (token ids,
        masks) keep their dtype, and non-tensor values are passed through.

        Args:
            batch: Mapping of input names to values (a processor output or a
                plain dict).

        Returns:
            A new ``dict`` suitable for ``model(**inputs)``.
        """
        prepared = {}
        for key, value in batch.items():
            if torch.is_tensor(value):
                if torch.is_floating_point(value):
                    value = value.to(device=self.device, dtype=self.torch_dtype)
                else:
                    value = value.to(self.device)
            prepared[key] = value
        return prepared

    @staticmethod
    def _prepare_image(image, max_side):
        """Return an RGB image no larger than ``max_side`` on either side.

        The caller's image is never modified: ``thumbnail`` resizes in place,
        so it is only applied to a converted or copied image.

        Args:
            image: A PIL image.
            max_side: Maximum width/height in pixels.

        Returns:
            The original object when it is already RGB and small enough,
            otherwise a new PIL image.
        """
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        if rgb.width > max_side or rgb.height > max_side:
            if rgb is image:
                rgb = rgb.copy()
            rgb.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
        return rgb

    @staticmethod
    def _round_to_multiple_of_8(value):
        """Round ``value`` down to a multiple of 8, never below 8."""
        return max(8, (int(value) // 8) * 8)

    def _enable_memory_savers(self, pipeline):
        """Apply the memory optimisations shared by both diffusion pipelines.

        Attention slicing is controlled by ``self.enable_attention_slicing``.
        xformers attention is optional and skipped (with a log line) when it
        cannot be enabled. VAE slicing is always enabled.
        """
        if self.enable_attention_slicing:
            pipeline.enable_attention_slicing()
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception as e:
            # Best-effort: xformers may be missing or unsupported here.
            logger.info(f"未启用 xformers 内存高效注意力: {e}")
        pipeline.enable_vae_slicing()

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------
    def optimize_memory_usage(self):
        """Turn on cuDNN autotuning and TF32 matmul/conv when CUDA is present."""
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    def load_all_models(self):
        """Load every model in order; the first failure propagates."""
        self.optimize_memory_usage()
        self.load_caption_model()
        self.load_clip_model()
        self.load_sd_pipeline()
        self.load_controlnet_pipeline()
        logger.info("所有模型加载完成")

    def load_caption_model(self):
        """Load the BLIP processor and captioning model in eval mode."""
        self.caption_processor = BlipProcessor.from_pretrained(self.model_config["caption_model"], cache_dir="/tmp/models")
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            self.model_config["caption_model"],
            cache_dir="/tmp/models",
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True
        ).to(self.device)
        # NOTE: attention slicing is a diffusers pipeline feature; calling
        # enable_attention_slicing() on this transformers model raised
        # AttributeError and aborted loading of all remaining models.
        self.caption_model.eval()

    def load_clip_model(self):
        """Load the CLIP processor and model in eval mode."""
        self.clip_processor = CLIPProcessor.from_pretrained(self.model_config["clip_model"], cache_dir="/tmp/models")
        self.clip_model = CLIPModel.from_pretrained(self.model_config["clip_model"], cache_dir="/tmp/models", torch_dtype=self.torch_dtype).to(self.device)
        self.clip_model.eval()

    def load_sd_pipeline(self):
        """Load the text-to-image pipeline with an Euler-ancestral scheduler.

        The safety checker is disabled and the shared memory optimisations
        are applied.
        """
        self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_config["sd_model"],
            torch_dtype=self.torch_dtype,
            cache_dir="/tmp/models",
            safety_checker=None,
            requires_safety_checker=False,
            use_safetensors=True,
            low_cpu_mem_usage=True
        ).to(self.device)
        self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.sd_pipeline.scheduler.config)
        self._enable_memory_savers(self.sd_pipeline)
        self.sd_pipeline.safety_checker = None

    def load_controlnet_pipeline(self):
        """Load the OpenPose ControlNet and the pipeline that wraps it.

        Uses the same base checkpoint, scheduler type and memory
        optimisations as ``load_sd_pipeline``.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            self.model_config["controlnet_model"],
            cache_dir="/tmp/models",
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True
        ).to(self.device)
        self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.model_config["sd_model"],
            controlnet=self.controlnet,
            cache_dir="/tmp/models",
            torch_dtype=self.torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
            low_cpu_mem_usage=True
        ).to(self.device)
        self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.controlnet_pipeline.scheduler.config)
        self._enable_memory_savers(self.controlnet_pipeline)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @torch.no_grad()
    def generate_caption(self, image):
        """Generate a text caption for ``image`` with the BLIP model.

        The image is converted to RGB and downscaled to at most 512x512 on a
        copy, so the caller's image is left untouched. Generation uses
        sampling (``do_sample=True``), so repeated calls may return different
        captions.

        Args:
            image: A PIL image.

        Returns:
            The decoded caption string.
        """
        image = self._prepare_image(image, 512)
        inputs = self._prepare_inputs(self.caption_processor(images=image, return_tensors="pt"))
        outputs = self.caption_model.generate(**inputs, max_length=50, num_beams=4, temperature=0.7, do_sample=True)
        caption = self.caption_processor.decode(outputs[0], skip_special_tokens=True)
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return caption

    @torch.no_grad()
    def analyze_style(self, image):
        """Score ``image`` against the fixed style prompts using CLIP.

        The image is converted to RGB and downscaled to at most 224x224 on a
        copy, so the caller's image is left untouched.

        Args:
            image: A PIL image.

        Returns:
            A dict mapping each style display name to its softmax probability
            (a Python ``float``); the probabilities sum to 1.
        """
        style_labels = [prompt for prompt, _ in self._STYLE_PROMPTS]
        style_names = [name for _, name in self._STYLE_PROMPTS]
        image = self._prepare_image(image, 224)
        inputs = self._prepare_inputs(
            self.clip_processor(text=style_labels, images=image, return_tensors="pt", padding=True, truncation=True, max_length=77)
        )
        outputs = self.clip_model(**inputs)
        # Softmax in float32 so half-precision logits do not lose accuracy.
        probs = outputs.logits_per_image.float().softmax(dim=1).cpu().numpy()[0]
        return {name: float(prob) for name, prob in zip(style_names, probs)}

    @torch.no_grad()
    def generate_image(self, prompt, negative_prompt=None, num_inference_steps=25, guidance_scale=7.5, width=512, height=512, seed=None):
        """Generate one image from ``prompt`` with the text-to-image pipeline.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative prompt; a generic quality prompt is used
                when ``None``.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            width: Requested width; rounded down to a multiple of 8 (min 8).
            height: Requested height; rounded down to a multiple of 8 (min 8).
            seed: Optional integer seed for reproducible output.

        Returns:
            The first image produced by the pipeline.
        """
        if negative_prompt is None:
            negative_prompt = "blurry, low quality, distorted, text, watermark, ugly, deformed"
        width = self._round_to_multiple_of_8(width)
        height = self._round_to_multiple_of_8(height)
        gen = torch.Generator(device=self.device).manual_seed(int(seed)) if seed is not None else None
        result = self.sd_pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=gen
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    @torch.no_grad()
    def generate_controlnet_image(self, image, prompt, reference_image=None, negative_prompt=None, num_inference_steps=30, guidance_scale=8.0, angle=0, width=512, height=768):
        """Generate one image conditioned on ``image`` via the ControlNet pipeline.

        Args:
            image: PIL image used as the ControlNet conditioning image; it is
                converted to RGB and resized to ``width`` x ``height``.
            prompt: Positive text prompt; a "view from N degrees" suffix is
                appended.
            reference_image: Only checked for presence; when not ``None`` an
                extra phrase is appended to the prompt. The pixels are unused.
            negative_prompt: Negative prompt; a default is used when ``None``.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            angle: Viewing angle in degrees, used in the prompt and the seed.
            width: Conditioning image width; rounded down to a multiple of 8.
            height: Conditioning image height; rounded down to a multiple of 8.

        Returns:
            The first image produced by the pipeline.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Honour the requested size (previously hard-coded to 512x768, which
        # is still the default).
        target_size = (self._round_to_multiple_of_8(width), self._round_to_multiple_of_8(height))
        control_image = image.resize(target_size, Image.Resampling.LANCZOS)
        if negative_prompt is None:
            negative_prompt = "blurry, distorted, low quality, unrealistic, extra limbs, deformed, bad anatomy, multiple people"
        prompt_with_angle = f"{prompt}, view from {angle} degrees"
        if reference_image is not None:
            prompt_with_angle = f"{prompt_with_angle}, based on provided reference design"
        # Time-based seed offset by the angle, so each call/angle differs.
        gen = torch.Generator(device=self.device).manual_seed(int(time.time()) + int(angle))
        result = self.controlnet_pipeline(
            prompt=prompt_with_angle,
            image=control_image,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=1.0,
            generator=gen
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result.images[0]

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    def create_placeholder_image(self, width, height):
        """Return a solid RGB image in one of four randomly chosen pale colours."""
        color = random.choice([(220, 220, 220), (200, 220, 240), (240, 220, 200), (220, 240, 200)])
        return Image.new('RGB', (width, height), color=color)

    def cleanup(self):
        """Run garbage collection and release cached CUDA memory if available."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            try:
                torch.cuda.ipc_collect()
            except Exception as e:
                # Best-effort: failing to collect IPC handles is not fatal.
                logger.debug(f"torch.cuda.ipc_collect 失败: {e}")

    def get_model_status(self) -> dict:
        """Report which models are loaded, the device, and GPU memory usage.

        Returns:
            A dict with boolean flags ``caption_model``, ``clip_model``,
            ``sd_pipeline`` and ``controlnet_pipeline``, the ``device``
            string, and, when CUDA is available, a ``gpu_memory`` dict with
            formatted ``allocated`` and ``cached`` sizes in GB.
        """
        status = {
            "caption_model": self.caption_model is not None,
            "clip_model": self.clip_model is not None,
            "sd_pipeline": self.sd_pipeline is not None,
            "controlnet_pipeline": self.controlnet_pipeline is not None,
            "device": self.device
        }
        if torch.cuda.is_available():
            status["gpu_memory"] = {
                "allocated": f"{torch.cuda.memory_allocated() / 1024**3:.2f}GB",
                "cached": f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB"
            }
        return status