Humphreykowl committed on
Commit
479bc48
·
verified ·
1 Parent(s): fa7fc80

Update models/model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +130 -0
models/model_manager.py CHANGED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/model_manager.py
2
+ from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
3
+ from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
4
+ import torch
5
+
6
class ModelManager:
    """Load and serve the captioning, text-generation and diffusion models.

    On construction the manager picks a device ("cuda" when available,
    otherwise "cpu") and eagerly loads:

    * a BLIP image-captioning processor/model,
    * a text-generation pipeline (DialoGPT-medium, falling back to GPT-2),
    * a Stable Diffusion v1.5 text-to-image pipeline,
    * an optional OpenPose ControlNet pipeline (``None`` if loading fails).
    """

    def __init__(self):
        """Select the compute device and load every model immediately."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.init_models()

    def _place_pipeline(self, pipe):
        """Move a diffusers pipeline to the target device and return it.

        On CUDA, model CPU offload is enabled and xformers attention is
        attempted; xformers is an optional dependency, so a failure there
        falls back to attention slicing instead of aborting the load.
        On CPU the pipeline is simply moved with ``.to``.
        """
        if self.device != "cuda":
            return pipe.to(self.device)

        pipe.enable_model_cpu_offload()
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            # xformers missing or incompatible: keep the pipeline usable.
            print(f"xformers 不可用,改用 attention slicing: {e}")
            pipe.enable_attention_slicing()
        return pipe

    def init_models(self):
        """Load all models and store them as instance attributes.

        Sets ``blip_processor``, ``blip_model``, ``text_generator``,
        ``sd_pipeline``, ``controlnet`` and ``controlnet_pipeline``.
        The last two are set to ``None`` when ControlNet cannot be loaded.
        Failures loading BLIP or the base Stable Diffusion pipeline
        propagate to the caller.
        """
        on_cuda = self.device == "cuda"
        # Half precision (and the fp16 weight variant) only on GPU.
        dtype = torch.float16 if on_cuda else torch.float32
        variant = "fp16" if on_cuda else None
        # transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
        pipeline_device = 0 if on_cuda else -1

        print("正在加载模型...")

        # Image understanding (captioning) model.
        print("加载图像理解模型...")
        self.blip_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base",
        )
        self.blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
        ).to(self.device)

        # Text generation model, with a smaller fallback if the first fails.
        print("加载文本生成模型...")
        try:
            self.text_generator = pipeline(
                "text-generation",
                model="microsoft/DialoGPT-medium",
                device=pipeline_device,
            )
        except Exception as e:
            print(f"DialoGPT 加载失败,使用备选模型: {e}")
            self.text_generator = pipeline(
                "text-generation",
                model="gpt2",
                device=pipeline_device,
            )

        # Base Stable Diffusion text-to-image pipeline.
        print("加载 Stable Diffusion 模型...")
        self.sd_pipeline = self._place_pipeline(
            StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=dtype,
                use_safetensors=True,
                variant=variant,
            )
        )

        # ControlNet is optional: any failure leaves both attributes None.
        print("加载 ControlNet 模型...")
        try:
            self.controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-openpose",
                torch_dtype=dtype,
                use_safetensors=True,
            )
            self.controlnet_pipeline = self._place_pipeline(
                StableDiffusionControlNetPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    controlnet=self.controlnet,
                    torch_dtype=dtype,
                    use_safetensors=True,
                    variant=variant,
                )
            )
        except Exception as e:
            print(f"ControlNet 加载失败: {e}")
            self.controlnet = None
            self.controlnet_pipeline = None

        print("所有模型加载完成!")

    def generate_caption(self, image):
        """Return a caption string for ``image`` produced by the BLIP model.

        ``image`` is passed straight to the BLIP processor; generation is
        capped at ``max_length=50`` and special tokens are stripped.
        """
        inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.blip_model.generate(**inputs, max_length=50)
        return self.blip_processor.decode(out[0], skip_special_tokens=True)

    def generate_text(self, prompt, max_length=100):
        """Generate text continuing ``prompt``.

        ``max_length`` is forwarded unchanged as the pipeline's
        ``max_length`` argument. Sampling is enabled with temperature 0.7.

        Returns the pipeline's ``generated_text`` for the single returned
        sequence. On any error the exception is printed and a string of
        the form ``"生成失败: <error>"`` is returned instead of raising.
        """
        try:
            result = self.text_generator(
                prompt,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                # Reuse EOS as the padding token id for generation.
                pad_token_id=self.text_generator.tokenizer.eos_token_id,
            )
            return result[0]['generated_text']
        except Exception as e:
            print(f"文本生成错误: {e}")
            return f"生成失败: {str(e)}"

    def generate_image(self, prompt, negative_prompt="", num_inference_steps=20):
        """Generate one image from ``prompt`` with the Stable Diffusion pipeline.

        Returns the first image of the pipeline output, or ``None`` if
        generation fails for any reason (including the pipeline having
        been released by ``cleanup``); the error is printed.
        """
        try:
            # No torch.autocast here: the pipeline was loaded in its target
            # dtype already, and autocast on CPU would switch to bfloat16.
            return self.sd_pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=7.5,
            ).images[0]
        except Exception as e:
            print(f"图像生成错误: {e}")
            return None

    def cleanup(self):
        """Release the diffusion models and free cached GPU memory.

        Deletes ``sd_pipeline``, ``controlnet_pipeline`` and ``controlnet``
        (the latter would otherwise keep the ControlNet weights alive),
        runs a garbage-collection pass, and empties the CUDA cache when
        CUDA is available. The BLIP and text-generation models are left
        loaded. Safe to call more than once.
        """
        import gc

        for attr in ("sd_pipeline", "controlnet_pipeline", "controlnet"):
            if hasattr(self, attr):
                delattr(self, attr)

        # Collect first so the freed tensors can actually be returned
        # by empty_cache().
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()