Humphreykowl commited on
Commit
49dfa5d
·
verified ·
1 Parent(s): 479bc48

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +411 -102
models/model_manager.py CHANGED
@@ -1,130 +1,439 @@
1
  # models/model_manager.py
2
- from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
3
- from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
4
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  class ModelManager:
7
  def __init__(self):
 
8
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
9
- self.init_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- def init_models(self):
12
- print("正在加载模型...")
13
-
14
- # 修复1: 使用兼容的 BLIP 模型
15
- print("加载图像理解模型...")
16
- self.blip_processor = BlipProcessor.from_pretrained(
17
- "Salesforce/blip-image-captioning-base",
18
- # 添加兼容性参数
19
- )
20
- self.blip_model = BlipForConditionalGeneration.from_pretrained(
21
- "Salesforce/blip-image-captioning-base",
22
- ).to(self.device)
23
-
24
- # 修复2: 文本生成模型 - 添加错误处理
25
- print("加载文本生成模型...")
26
  try:
27
- self.text_generator = pipeline(
28
- "text-generation",
29
- model="microsoft/DialoGPT-medium",
30
- device=0 if self.device == "cuda" else -1
31
- )
32
- except Exception as e:
33
- print(f"DialoGPT 加载失败,使用备选模型: {e}")
34
- self.text_generator = pipeline(
35
- "text-generation",
36
- model="gpt2",
37
- device=0 if self.device == "cuda" else -1
38
- )
 
 
 
39
 
40
- # 修复3: Stable Diffusion 模型 - 添加内存优化
41
- print("加载 Stable Diffusion 模型...")
42
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
43
- "runwayml/stable-diffusion-v1-5",
44
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
45
- use_safetensors=True,
46
- variant="fp16" if self.device == "cuda" else None
47
- )
48
-
49
- # 内存优化
50
- if self.device == "cuda":
51
- self.sd_pipeline.enable_model_cpu_offload()
52
- self.sd_pipeline.enable_xformers_memory_efficient_attention()
53
- else:
54
- self.sd_pipeline = self.sd_pipeline.to(self.device)
55
-
56
- # 修复4: ControlNet 模型 - 添加错误处理
57
- print("加载 ControlNet 模型...")
58
  try:
59
- self.controlnet = ControlNetModel.from_pretrained(
60
- "lllyasviel/sd-controlnet-openpose",
61
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
62
- use_safetensors=True
63
- )
64
 
65
- self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
66
- "runwayml/stable-diffusion-v1-5",
67
- controlnet=self.controlnet,
68
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
69
- use_safetensors=True,
70
- variant="fp16" if self.device == "cuda" else None
71
- )
72
 
73
- if self.device == "cuda":
74
- self.controlnet_pipeline.enable_model_cpu_offload()
75
- self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
76
- else:
77
- self.controlnet_pipeline = self.controlnet_pipeline.to(self.device)
78
-
79
- except Exception as e:
80
- print(f"ControlNet 加载失败: {e}")
81
- self.controlnet = None
82
- self.controlnet_pipeline = None
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- print("所有模型加载完成!")
85
-
86
- def generate_caption(self, image):
87
- """生成图像描述"""
88
- inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
89
- with torch.no_grad():
90
- out = self.blip_model.generate(**inputs, max_length=50)
91
- return self.blip_processor.decode(out[0], skip_special_tokens=True)
92
 
93
- def generate_text(self, prompt, max_length=100):
94
- """生成文本"""
 
 
 
 
 
 
 
 
95
  try:
96
- result = self.text_generator(
97
- prompt,
98
- max_length=max_length,
99
- num_return_sequences=1,
100
- temperature=0.7,
101
- do_sample=True,
102
- pad_token_id=self.text_generator.tokenizer.eos_token_id
103
- )
104
- return result[0]['generated_text']
 
 
 
 
 
 
 
 
105
  except Exception as e:
106
- print(f"文本生成错误: {e}")
107
- return f"生成失败: {str(e)}"
 
108
 
109
- def generate_image(self, prompt, negative_prompt="", num_inference_steps=20):
110
- """生成图像"""
 
 
 
 
 
 
 
111
  try:
112
- with torch.autocast(self.device):
113
- image = self.sd_pipeline(
 
 
 
 
114
  prompt=prompt,
 
115
  negative_prompt=negative_prompt,
116
  num_inference_steps=num_inference_steps,
117
- guidance_scale=7.5
 
118
  ).images[0]
 
 
119
  return image
 
120
  except Exception as e:
121
- print(f"图像生成错误: {e}")
122
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  def cleanup(self):
125
- """清理 GPU 内存"""
126
- if hasattr(self, 'sd_pipeline'):
 
 
 
 
 
 
 
 
 
 
 
127
  del self.sd_pipeline
128
- if hasattr(self, 'controlnet_pipeline'):
129
  del self.controlnet_pipeline
130
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # models/model_manager.py
 
 
2
  import torch
3
+ from PIL import Image
4
+ from transformers import (
5
+ BlipProcessor,
6
+ BlipForConditionalGeneration,
7
+ CLIPProcessor,
8
+ CLIPModel
9
+ )
10
+ from diffusers import (
11
+ StableDiffusionPipeline,
12
+ StableDiffusionControlNetPipeline,
13
+ ControlNetModel,
14
+ EulerAncestralDiscreteScheduler
15
+ )
16
+ import numpy as np
17
+ import gc
18
+ import os
19
+ import logging
20
+ import time
21
+ from typing import Optional, Dict, List, Tuple
22
+
23
+ # 设置日志
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
 
27
  class ModelManager:
28
  def __init__(self):
29
+ # 自动检测设备
30
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ logger.info(f"使用设备: {self.device}")
32
+
33
+ # 初始化模型为空
34
+ self.caption_model = None
35
+ self.caption_processor = None
36
+ self.clip_model = None
37
+ self.clip_processor = None
38
+ self.sd_pipeline = None
39
+ self.controlnet_pipeline = None
40
+ self.controlnet = None
41
+
42
+ # 模型配置 - 使用较小的模型变体以适应 Space 环境
43
+ self.model_config = {
44
+ "caption_model": "Salesforce/blip-image-captioning-base", # 基础版节省内存
45
+ "clip_model": "openai/clip-vit-base-patch32", # 基础版CLIP
46
+ "sd_model": "stabilityai/stable-diffusion-2-1-base", # SD 2.1基础版
47
+ "controlnet_model": "lllyasviel/sd-controlnet-openpose" # 姿势控制模型
48
+ }
49
+
50
+ # 创建缓存目录 - 使用Space的临时目录
51
+ self.cache_dir = "/tmp/models"
52
+ os.makedirs(self.cache_dir, exist_ok=True)
53
+ logger.info(f"模型缓存目录: {self.cache_dir}")
54
+
55
+ # 加载统计
56
+ self.load_times = {}
57
+ self.last_used = {}
58
+
59
+ def load_caption_model(self):
60
+ """加载图像描述模型"""
61
+ if self.caption_model is None:
62
+ start_time = time.time()
63
+ logger.info("正在加载图像描述模型...")
64
+
65
+ try:
66
+ self.caption_processor = BlipProcessor.from_pretrained(
67
+ self.model_config["caption_model"],
68
+ cache_dir=self.cache_dir
69
+ )
70
+
71
+ self.caption_model = BlipForConditionalGeneration.from_pretrained(
72
+ self.model_config["caption_model"],
73
+ cache_dir=self.cache_dir,
74
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
75
+ ).to(self.device)
76
+
77
+ # 模型优化
78
+ if self.device == "cuda":
79
+ self.caption_model = self.caption_model.half()
80
+
81
+ logger.info("图像描述模型加载完成")
82
+ self.load_times["caption"] = time.time() - start_time
83
+ self.last_used["caption"] = time.time()
84
+ except Exception as e:
85
+ logger.error(f"加载描述模型失败: {str(e)}")
86
+ # 尝试回退到更小的模型
87
+ self.model_config["caption_model"] = "Salesforce/blip-image-captioning-base"
88
+ self.load_caption_model()
89
+
90
+ def load_clip_model(self):
91
+ """加载CLIP模型用于风格分析"""
92
+ if self.clip_model is None:
93
+ start_time = time.time()
94
+ logger.info("正在加载CLIP模型...")
95
+
96
+ try:
97
+ self.clip_processor = CLIPProcessor.from_pretrained(
98
+ self.model_config["clip_model"],
99
+ cache_dir=self.cache_dir
100
+ )
101
+
102
+ self.clip_model = CLIPModel.from_pretrained(
103
+ self.model_config["clip_model"],
104
+ cache_dir=self.cache_dir,
105
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
106
+ ).to(self.device)
107
+
108
+ # 模型优化
109
+ if self.device == "cuda":
110
+ self.clip_model = self.clip_model.half()
111
+
112
+ logger.info("CLIP模型加载完成")
113
+ self.load_times["clip"] = time.time() - start_time
114
+ self.last_used["clip"] = time.time()
115
+ except Exception as e:
116
+ logger.error(f"加载CLIP模型失败: {str(e)}")
117
+
118
+ def load_sd_pipeline(self):
119
+ """加载Stable Diffusion生成管道"""
120
+ if self.sd_pipeline is None:
121
+ start_time = time.time()
122
+ logger.info("正在加载Stable Diffusion模型...")
123
+
124
+ # 根据可用内存选择模型变体
125
+ if self.device == "cuda" and torch.cuda.get_device_properties(0).total_memory < 10 * 1024**3:
126
+ logger.info("检测到有限GPU内存,使用更小的SD模型")
127
+ self.model_config["sd_model"] = "runwayml/stable-diffusion-v1-5"
128
+
129
+ try:
130
+ self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
131
+ self.model_config["sd_model"],
132
+ cache_dir=self.cache_dir,
133
+ safety_checker=None, # 禁用安全检查以节省内存
134
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
135
+ ).to(self.device)
136
+
137
+ # 优化性能
138
+ if self.device == "cuda":
139
+ try:
140
+ # 启用内存高效注意力
141
+ self.sd_pipeline.enable_xformers_memory_efficient_attention()
142
+ except:
143
+ logger.warning("无法启用xformers,使用回退方案")
144
+
145
+ # 启用注意力切片
146
+ self.sd_pipeline.enable_attention_slicing()
147
+
148
+ logger.info("Stable Diffusion模型加载完成")
149
+ self.load_times["sd"] = time.time() - start_time
150
+ self.last_used["sd"] = time.time()
151
+ except Exception as e:
152
+ logger.error(f"加载SD模型失败: {str(e)}")
153
+ # 尝试回退到更小的模型
154
+ self.model_config["sd_model"] = "runwayml/stable-diffusion-v1-5"
155
+ self.load_sd_pipeline()
156
+
157
+ def load_controlnet_pipeline(self):
158
+ """加载ControlNet管道用于3D试穿"""
159
+ if self.controlnet_pipeline is None:
160
+ start_time = time.time()
161
+ logger.info("正在加载ControlNet模型...")
162
+
163
+ try:
164
+ # 先加载ControlNet模型
165
+ self.controlnet = ControlNetModel.from_pretrained(
166
+ self.model_config["controlnet_model"],
167
+ cache_dir=self.cache_dir,
168
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
169
+ )
170
+
171
+ # 然后创建ControlNet管道
172
+ self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
173
+ self.model_config["sd_model"],
174
+ controlnet=self.controlnet,
175
+ cache_dir=self.cache_dir,
176
+ safety_checker=None,
177
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
178
+ ).to(self.device)
179
+
180
+ # 设置调度器
181
+ self.controlnet_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
182
+ self.controlnet_pipeline.scheduler.config
183
+ )
184
+
185
+ # 优化性能
186
+ if self.device == "cuda":
187
+ try:
188
+ self.controlnet_pipeline.enable_xformers_memory_efficient_attention()
189
+ except:
190
+ logger.warning("无法为ControlNet启用xformers")
191
+
192
+ self.controlnet_pipeline.enable_attention_slicing()
193
+
194
+ logger.info("ControlNet模型加载完成")
195
+ self.load_times["controlnet"] = time.time() - start_time
196
+ self.last_used["controlnet"] = time.time()
197
+ except Exception as e:
198
+ logger.error(f"加载ControlNet模型失败: {str(e)}")
199
 
200
+ def generate_caption(self, image: Image.Image) -> str:
201
+ """为图像生成描述性标题"""
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  try:
203
+ self.load_caption_model()
204
+ self.last_used["caption"] = time.time()
205
+
206
+ # 准备输入
207
+ inputs = self.caption_processor(
208
+ images=image,
209
+ return_tensors="pt"
210
+ ).to(self.device, torch.float16 if self.device == "cuda" else torch.float32)
211
+
212
+ # 生成标题
213
+ output = self.caption_model.generate(**inputs, max_length=50)
214
+ caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
215
+
216
+ logger.info(f"生成的标题: {caption}")
217
+ return caption
218
 
219
+ except Exception as e:
220
+ logger.error(f"生成标题失败: {str(e)}")
221
+ # 返回默认标题
222
+ return "时尚服装设计"
223
+
224
+ def analyze_style(self, image: Image.Image) -> Dict[str, float]:
225
+ """使用CLIP分析图像风格"""
 
 
 
 
 
 
 
 
 
 
 
226
  try:
227
+ self.load_clip_model()
228
+ self.last_used["clip"] = time.time()
 
 
 
229
 
230
+ # 定义风格类别
231
+ style_labels = [
232
+ "商务正装", "休闲风", "运动风", "时尚潮流",
233
+ "复古风", "街头风", "优雅风", "民族风"
234
+ ]
 
 
235
 
236
+ # 准备输入
237
+ inputs = self.clip_processor(
238
+ text=style_labels,
239
+ images=image,
240
+ return_tensors="pt",
241
+ padding=True
242
+ ).to(self.device)
243
+
244
+ # 获取预测
245
+ outputs = self.clip_model(**inputs)
246
+ logits_per_image = outputs.logits_per_image
247
+ probs = logits_per_image.softmax(dim=1).detach().cpu().numpy()[0]
248
+
249
+ # 获取前3个风格
250
+ top3_idx = np.argsort(probs)[-3:][::-1]
251
+ top_styles = {
252
+ style_labels[i]: float(probs[i]) for i in top3_idx
253
+ }
254
+
255
+ logger.info(f"风格分析结果: {top_styles}")
256
+ return top_styles
257
 
258
+ except Exception as e:
259
+ logger.error(f"风格分析失败: {str(e)}")
260
+ # 返回默认风格
261
+ return {"休闲风": 0.8, "时尚潮流": 0.7}
 
 
 
 
262
 
263
+ def generate_image(
264
+ self,
265
+ prompt: str,
266
+ negative_prompt: str = "",
267
+ num_inference_steps: int = 30,
268
+ guidance_scale: float = 7.5,
269
+ height: int = 512,
270
+ width: int = 512
271
+ ) -> Image.Image:
272
+ """根据提示生成设计图像"""
273
  try:
274
+ self.load_sd_pipeline()
275
+ self.last_used["sd"] = time.time()
276
+
277
+ # 生成图像
278
+ with torch.autocast("cuda" if self.device == "cuda" else "cpu"):
279
+ image = self.sd_pipeline(
280
+ prompt=prompt,
281
+ negative_prompt=negative_prompt,
282
+ num_inference_steps=num_inference_steps,
283
+ guidance_scale=guidance_scale,
284
+ height=height,
285
+ width=width
286
+ ).images[0]
287
+
288
+ logger.info(f"成功生成设计图像: {prompt[:50]}...")
289
+ return image
290
+
291
  except Exception as e:
292
+ logger.error(f"生成设计图像失败: {str(e)}")
293
+ # 创建占位图像
294
+ return Image.new('RGB', (512, 512), color=(220, 220, 220))
295
 
296
+ def generate_controlnet_image(
297
+ self,
298
+ image: Image.Image,
299
+ prompt: str,
300
+ negative_prompt: str = "",
301
+ num_inference_steps: int = 35,
302
+ guidance_scale: float = 8.0
303
+ ) -> Image.Image:
304
+ """使用ControlNet生成3D试穿图像"""
305
  try:
306
+ self.load_controlnet_pipeline()
307
+ self.last_used["controlnet"] = time.time()
308
+
309
+ # 生成图像
310
+ with torch.autocast("cuda" if self.device == "cuda" else "cpu"):
311
+ image = self.controlnet_pipeline(
312
  prompt=prompt,
313
+ image=image,
314
  negative_prompt=negative_prompt,
315
  num_inference_steps=num_inference_steps,
316
+ guidance_scale=guidance_scale,
317
+ controlnet_conditioning_scale=0.8
318
  ).images[0]
319
+
320
+ logger.info(f"成功生成3D试穿图像")
321
  return image
322
+
323
  except Exception as e:
324
+ logger.error(f"生成3D试穿图像失败: {str(e)}")
325
+ # 回退到普通SD模型
326
+ return self.generate_image(
327
+ prompt,
328
+ negative_prompt,
329
+ num_inference_steps
330
+ )
331
+
332
+ def unload_model(self, model_type: str):
333
+ """卸载指定类型的模型以释放内存"""
334
+ logger.info(f"卸载模型: {model_type}")
335
+
336
+ if model_type == "caption" and self.caption_model is not None:
337
+ del self.caption_model
338
+ del self.caption_processor
339
+ self.caption_model = None
340
+ self.caption_processor = None
341
+ logger.info("卸载图像描述模型")
342
+
343
+ elif model_type == "clip" and self.clip_model is not None:
344
+ del self.clip_model
345
+ del self.clip_processor
346
+ self.clip_model = None
347
+ self.clip_processor = None
348
+ logger.info("卸载CLIP模型")
349
+
350
+ elif model_type == "sd" and self.sd_pipeline is not None:
351
+ del self.sd_pipeline
352
+ self.sd_pipeline = None
353
+ logger.info("卸载Stable Diffusion模型")
354
+
355
+ elif model_type == "controlnet" and self.controlnet_pipeline is not None:
356
+ del self.controlnet_pipeline
357
+ del self.controlnet
358
+ self.controlnet_pipeline = None
359
+ self.controlnet = None
360
+ logger.info("卸载ControlNet模型")
361
+
362
+ # 清理内存
363
+ self.cleanup_memory()
364
 
365
  def cleanup(self):
366
+ """清理所有模型释放内存"""
367
+ logger.info("清理所有模型释放内存...")
368
+
369
+ # 释放所有模型
370
+ if self.caption_model is not None:
371
+ del self.caption_model
372
+ if self.caption_processor is not None:
373
+ del self.caption_processor
374
+ if self.clip_model is not None:
375
+ del self.clip_model
376
+ if self.clip_processor is not None:
377
+ del self.clip_processor
378
+ if self.sd_pipeline is not None:
379
  del self.sd_pipeline
380
+ if self.controlnet_pipeline is not None:
381
  del self.controlnet_pipeline
382
+ if self.controlnet is not None:
383
+ del self.controlnet
384
+
385
+ # 重置引用
386
+ self.caption_model = None
387
+ self.caption_processor = None
388
+ self.clip_model = None
389
+ self.clip_processor = None
390
+ self.sd_pipeline = None
391
+ self.controlnet_pipeline = None
392
+ self.controlnet = None
393
+
394
+ # ��理内存
395
+ self.cleanup_memory()
396
+ logger.info("内存清理完成")
397
+
398
+ def cleanup_memory(self):
399
+ """执行内存清理操作"""
400
+ # 清理CUDA缓存
401
+ if torch.cuda.is_available():
402
+ torch.cuda.empty_cache()
403
+
404
+ # 执行垃圾回收
405
+ gc.collect()
406
+
407
+ def get_memory_usage(self) -> Dict[str, float]:
408
+ """获取当前内存使用情况"""
409
+ mem_info = {}
410
+
411
+ if torch.cuda.is_available():
412
+ mem_info["gpu_total"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
413
+ mem_info["gpu_used"] = torch.cuda.memory_allocated() / (1024**3)
414
+ mem_info["gpu_free"] = mem_info["gpu_total"] - mem_info["gpu_used"]
415
+
416
+ return mem_info
417
+
418
+ def get_model_status(self) -> Dict[str, str]:
419
+ """获取模型加载状态"""
420
+ status = {
421
+ "caption_model": "已加载" if self.caption_model else "未加载",
422
+ "clip_model": "已加载" if self.clip_model else "未加载",
423
+ "sd_model": "已加载" if self.sd_pipeline else "未加载",
424
+ "controlnet_model": "已加载" if self.controlnet_pipeline else "未加载"
425
+ }
426
+
427
+ # 添加加载时间信息
428
+ for model in ["caption", "clip", "sd", "controlnet"]:
429
+ if model in self.load_times:
430
+ status[f"{model}_load_time"] = f"{self.load_times[model]:.2f}秒"
431
+ if model in self.last_used:
432
+ mins_ago = (time.time() - self.last_used[model]) / 60
433
+ status[f"{model}_last_used"] = f"{mins_ago:.1f}分钟前"
434
+
435
+ return status
436
+
437
+ def __del__(self):
438
+ """析构函数确保资源释放"""
439
+ self.cleanup()