annoyingpixel commited on
Commit
a4355c4
Β·
verified Β·
1 Parent(s): 7c02427

Upload flux_space_lora_manager.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. flux_space_lora_manager.py +303 -0
flux_space_lora_manager.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ LoRA Manager for FLUX.1 Space - Advanced LoRA handling and integration
4
+ """
5
+
6
+ import torch
7
+ from safetensors.torch import load_file, save_file
8
+ import os
9
+ import json
10
+ from typing import Dict, List, Optional, Tuple
11
+ import numpy as np
12
+
13
+ class FluxLoRAManager:
14
+ """
15
+ Advanced LoRA manager for FLUX models
16
+ """
17
+
18
+ def __init__(self):
19
+ self.loaded_loras = {}
20
+ self.lora_metadata = {}
21
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
+ def load_lora_file(self, lora_path: str, lora_name: str = None) -> Dict:
24
+ """
25
+ Load a LoRA file and extract metadata
26
+ """
27
+ try:
28
+ print(f"πŸ”„ Loading LoRA file: {lora_path}")
29
+
30
+ # Load LoRA weights
31
+ lora_state_dict = load_file(lora_path)
32
+
33
+ # Extract metadata
34
+ metadata = self._extract_lora_metadata(lora_path, lora_state_dict)
35
+
36
+ # Generate name if not provided
37
+ if lora_name is None:
38
+ lora_name = os.path.splitext(os.path.basename(lora_path))[0]
39
+
40
+ # Store LoRA
41
+ self.loaded_loras[lora_name] = {
42
+ 'path': lora_path,
43
+ 'state_dict': lora_state_dict,
44
+ 'metadata': metadata,
45
+ 'strength': 1.0,
46
+ 'active': False
47
+ }
48
+
49
+ print(f"βœ… LoRA '{lora_name}' loaded successfully")
50
+ print(f"πŸ“Š Metadata: {metadata}")
51
+
52
+ return {
53
+ 'name': lora_name,
54
+ 'metadata': metadata,
55
+ 'success': True
56
+ }
57
+
58
+ except Exception as e:
59
+ print(f"❌ Error loading LoRA: {e}")
60
+ return {
61
+ 'name': lora_name,
62
+ 'error': str(e),
63
+ 'success': False
64
+ }
65
+
66
+ def _extract_lora_metadata(self, lora_path: str, state_dict: Dict) -> Dict:
67
+ """
68
+ Extract metadata from LoRA file
69
+ """
70
+ metadata = {
71
+ 'filename': os.path.basename(lora_path),
72
+ 'file_size_mb': os.path.getsize(lora_path) / (1024 * 1024),
73
+ 'tensor_count': len(state_dict),
74
+ 'tensor_names': list(state_dict.keys()),
75
+ 'base_model': 'unknown',
76
+ 'training_info': {}
77
+ }
78
+
79
+ # Try to load JSON metadata if it exists
80
+ json_path = lora_path.replace('.safetensors', '.json')
81
+ if os.path.exists(json_path):
82
+ try:
83
+ with open(json_path, 'r') as f:
84
+ json_metadata = json.load(f)
85
+ metadata.update(json_metadata)
86
+ except:
87
+ pass
88
+
89
+ # Analyze tensor structure to determine base model
90
+ if any('double_blocks' in key for key in state_dict.keys()):
91
+ metadata['base_model'] = 'FLUX'
92
+ elif any('unet' in key for key in state_dict.keys()):
93
+ metadata['base_model'] = 'Stable Diffusion'
94
+
95
+ return metadata
96
+
97
+ def apply_lora_to_model(self, lora_name: str, model_pipeline, strength: float = 1.0) -> bool:
98
+ """
99
+ Apply a LoRA to a model pipeline
100
+ """
101
+ if lora_name not in self.loaded_loras:
102
+ print(f"❌ LoRA '{lora_name}' not loaded")
103
+ return False
104
+
105
+ try:
106
+ print(f"πŸ”„ Applying LoRA '{lora_name}' with strength {strength}")
107
+
108
+ lora_data = self.loaded_loras[lora_name]
109
+ state_dict = lora_data['state_dict']
110
+
111
+ # Apply LoRA weights with strength
112
+ for key, value in state_dict.items():
113
+ if key in model_pipeline.state_dict():
114
+ # Scale the LoRA weights by strength
115
+ scaled_value = value * strength
116
+ model_pipeline.state_dict()[key].copy_(scaled_value)
117
+
118
+ # Update LoRA status
119
+ lora_data['strength'] = strength
120
+ lora_data['active'] = True
121
+
122
+ print(f"βœ… LoRA '{lora_name}' applied successfully")
123
+ return True
124
+
125
+ except Exception as e:
126
+ print(f"❌ Error applying LoRA: {e}")
127
+ return False
128
+
129
+ def remove_lora_from_model(self, lora_name: str, model_pipeline) -> bool:
130
+ """
131
+ Remove a LoRA from a model pipeline
132
+ """
133
+ if lora_name not in self.loaded_loras:
134
+ print(f"❌ LoRA '{lora_name}' not loaded")
135
+ return False
136
+
137
+ try:
138
+ print(f"πŸ”„ Removing LoRA '{lora_name}'")
139
+
140
+ lora_data = self.loaded_loras[lora_name]
141
+ state_dict = lora_data['state_dict']
142
+
143
+ # Remove LoRA weights (set to zero)
144
+ for key, value in state_dict.items():
145
+ if key in model_pipeline.state_dict():
146
+ model_pipeline.state_dict()[key].zero_()
147
+
148
+ # Update LoRA status
149
+ lora_data['active'] = False
150
+
151
+ print(f"βœ… LoRA '{lora_name}' removed successfully")
152
+ return True
153
+
154
+ except Exception as e:
155
+ print(f"❌ Error removing LoRA: {e}")
156
+ return False
157
+
158
+ def blend_loras(self, lora_names: List[str], weights: List[float]) -> Dict:
159
+ """
160
+ Blend multiple LoRAs with specified weights
161
+ """
162
+ if len(lora_names) != len(weights):
163
+ print("❌ Number of LoRAs and weights must match")
164
+ return {'success': False, 'error': 'Mismatched arrays'}
165
+
166
+ try:
167
+ print(f"πŸ”„ Blending LoRAs: {lora_names}")
168
+ print(f"πŸ“Š Weights: {weights}")
169
+
170
+ # Normalize weights
171
+ total_weight = sum(weights)
172
+ normalized_weights = [w / total_weight for w in weights]
173
+
174
+ # Get all unique tensor keys
175
+ all_keys = set()
176
+ for lora_name in lora_names:
177
+ if lora_name in self.loaded_loras:
178
+ all_keys.update(self.loaded_loras[lora_name]['state_dict'].keys())
179
+
180
+ # Create blended state dict
181
+ blended_state_dict = {}
182
+ for key in all_keys:
183
+ blended_tensor = None
184
+ for lora_name, weight in zip(lora_names, normalized_weights):
185
+ if lora_name in self.loaded_loras:
186
+ lora_state_dict = self.loaded_loras[lora_name]['state_dict']
187
+ if key in lora_state_dict:
188
+ if blended_tensor is None:
189
+ blended_tensor = lora_state_dict[key] * weight
190
+ else:
191
+ blended_tensor += lora_state_dict[key] * weight
192
+
193
+ if blended_tensor is not None:
194
+ blended_state_dict[key] = blended_tensor
195
+
196
+ # Create blended LoRA name
197
+ blended_name = f"blended_{'_'.join(lora_names)}"
198
+
199
+ # Store blended LoRA
200
+ self.loaded_loras[blended_name] = {
201
+ 'path': 'blended',
202
+ 'state_dict': blended_state_dict,
203
+ 'metadata': {
204
+ 'blended_from': lora_names,
205
+ 'weights': normalized_weights,
206
+ 'base_model': 'FLUX'
207
+ },
208
+ 'strength': 1.0,
209
+ 'active': False
210
+ }
211
+
212
+ print(f"βœ… Blended LoRA '{blended_name}' created successfully")
213
+ return {
214
+ 'success': True,
215
+ 'blended_name': blended_name,
216
+ 'tensor_count': len(blended_state_dict)
217
+ }
218
+
219
+ except Exception as e:
220
+ print(f"❌ Error blending LoRAs: {e}")
221
+ return {'success': False, 'error': str(e)}
222
+
223
+ def get_lora_info(self, lora_name: str) -> Dict:
224
+ """
225
+ Get detailed information about a loaded LoRA
226
+ """
227
+ if lora_name not in self.loaded_loras:
228
+ return {'error': f"LoRA '{lora_name}' not found"}
229
+
230
+ lora_data = self.loaded_loras[lora_name]
231
+ return {
232
+ 'name': lora_name,
233
+ 'path': lora_data['path'],
234
+ 'active': lora_data['active'],
235
+ 'strength': lora_data['strength'],
236
+ 'metadata': lora_data['metadata']
237
+ }
238
+
239
+ def get_all_loras_info(self) -> List[Dict]:
240
+ """
241
+ Get information about all loaded LoRAs
242
+ """
243
+ return [self.get_lora_info(name) for name in self.loaded_loras.keys()]
244
+
245
+ def save_blended_lora(self, blended_name: str, output_path: str) -> bool:
246
+ """
247
+ Save a blended LoRA to file
248
+ """
249
+ if blended_name not in self.loaded_loras:
250
+ print(f"❌ Blended LoRA '{blended_name}' not found")
251
+ return False
252
+
253
+ try:
254
+ print(f"πŸ’Ύ Saving blended LoRA to: {output_path}")
255
+
256
+ lora_data = self.loaded_loras[blended_name]
257
+ state_dict = lora_data['state_dict']
258
+ metadata = lora_data['metadata']
259
+
260
+ # Save safetensors file
261
+ save_file(state_dict, output_path, metadata=metadata)
262
+
263
+ # Save JSON metadata
264
+ json_path = output_path.replace('.safetensors', '.json')
265
+ with open(json_path, 'w') as f:
266
+ json.dump(metadata, f, indent=2)
267
+
268
+ print(f"βœ… Blended LoRA saved successfully")
269
+ return True
270
+
271
+ except Exception as e:
272
+ print(f"❌ Error saving blended LoRA: {e}")
273
+ return False
274
+
275
+ # Utility functions for Gradio integration
276
+ def create_lora_manager():
277
+ """
278
+ Create and return a LoRA manager instance
279
+ """
280
+ return FluxLoRAManager()
281
+
282
+ def validate_lora_file(file_path: str) -> Dict:
283
+ """
284
+ Validate a LoRA file before loading
285
+ """
286
+ try:
287
+ if not os.path.exists(file_path):
288
+ return {'valid': False, 'error': 'File not found'}
289
+
290
+ if not file_path.endswith('.safetensors'):
291
+ return {'valid': False, 'error': 'File must be .safetensors format'}
292
+
293
+ # Try to load the file
294
+ state_dict = load_file(file_path)
295
+
296
+ return {
297
+ 'valid': True,
298
+ 'tensor_count': len(state_dict),
299
+ 'file_size_mb': os.path.getsize(file_path) / (1024 * 1024)
300
+ }
301
+
302
+ except Exception as e:
303
+ return {'valid': False, 'error': str(e)}