| |
import gc
import os
import tempfile

import fitz
import gradio as gr
import psutil
import torch
from diffusers import FluxPipeline
from huggingface_hub import login
from PIL import Image
|
|
| |
# Process-wide torch settings for this CPU-oriented inference app.
# Autograd is never needed for inference. NOTE: grad mode is thread-local, so
# this only affects the importing thread; pipeline calls made from other
# threads still need their own torch.no_grad().
torch.set_grad_enabled(False)
# Cap intra-op parallelism at two threads (presumably sized for a small
# shared host — confirm).
torch.set_num_threads(2)
# Factory calls without an explicit device create tensors on the CPU.
torch.set_default_device("cpu")
|
|
def get_memory_usage():
    """Return the resident set size of the current process in GiB.

    The byte count comes from psutil's ``memory_info().rss`` for this PID and
    is divided by 1024**3, so the result is a float.
    """
    bytes_per_gib = 1024 ** 3
    current = psutil.Process(os.getpid())
    return current.memory_info().rss / bytes_per_gib
|
|
def load_pdf(pdf_path):
    """Extract the text of every page of a PDF file.

    Args:
        pdf_path: Path of the PDF to read, or None.

    Returns:
        The text of all pages concatenated in page order (an empty string if
        the PDF has no extractable text), or None when ``pdf_path`` is None
        or the file cannot be opened or read. Read errors are printed.
    """
    if pdf_path is None:
        return None
    try:
        # The context manager closes the document even when text extraction
        # fails part-way through, so the file handle is never leaked.
        with fitz.open(pdf_path) as doc:
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        print(f"Erreur lors de la lecture du PDF: {str(e)}")
        return None
|
|
class FluxGenerator:
    """Text-to-image generator built on the FLUX.1-schnell diffusers pipeline.

    Creating an instance reads a Hugging Face token from the
    ``Authentification_HF`` environment variable, logs in to the Hub with it
    and loads the model immediately, so construction is slow and needs
    network access.
    """

    def __init__(self):
        """Authenticate with Hugging Face and load the pipeline.

        Raises:
            ValueError: if ``Authentification_HF`` is unset or empty.
        """
        self.token = os.getenv('Authentification_HF')
        if not self.token:
            raise ValueError("Token d'authentification HuggingFace non trouvé")
        login(self.token)
        self.pipeline = None
        self.device = "cpu"
        self.load_model()

    @staticmethod
    def _pretrained_kwargs():
        """Return the keyword arguments passed to ``FluxPipeline.from_pretrained``."""
        return {
            "revision": "refs/pr/1",
            "low_cpu_mem_usage": True,
            # torch has no `torch.float8` attribute (only concrete variants
            # such as float8_e4m3fn), so that value raised AttributeError
            # before anything loaded. bfloat16 halves memory versus float32
            # and is the dtype used in the diffusers FLUX examples.
            "torch_dtype": torch.bfloat16,
            "use_safetensors": True,
            # Deliberately no device_map: diffusers refuses to enable
            # sequential CPU offload on a device-mapped pipeline.
        }

    @staticmethod
    def _release_memory():
        """Run the garbage collector and empty the CUDA cache when CUDA exists."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self):
        """Load FLUX.1-schnell into ``self.pipeline`` with memory-saving options.

        Sequential CPU offload streams weights onto an accelerator, so it is
        only enabled when CUDA is available; on a CPU-only host the pipeline
        stays on the CPU. Any loading error is printed and re-raised.
        """
        try:
            print("Chargement du modèle FLUX avec optimisations mémoire...")
            print(f"Mémoire utilisée avant chargement: {get_memory_usage():.2f} GB")

            self.pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                **self._pretrained_kwargs(),
            )

            if torch.cuda.is_available():
                self.pipeline.enable_sequential_cpu_offload()
            # slice_size=1 is the most aggressive attention slicing: slower,
            # but with lower peak memory.
            self.pipeline.enable_attention_slicing(slice_size=1)

            self._release_memory()

            print(f"Mémoire utilisée après chargement: {get_memory_usage():.2f} GB")
            print("Modèle FLUX chargé avec succès en mode basse consommation!")

        except Exception as e:
            print(f"Erreur lors du chargement du modèle: {str(e)}")
            raise

    def generate_image(self, prompt, reference_image=None, pdf_file=None):
        """Generate one 512x512 image from ``prompt``.

        Args:
            prompt: Text description of the desired image.
            reference_image: Accepted for interface compatibility but currently
                unused — it is never passed to the pipeline.
            pdf_file: Optional PDF path whose extracted text is appended to
                the prompt as context.

        Returns:
            The first image returned by the pipeline, or None if anything
            fails (the error is printed).
        """
        try:
            if pdf_file is not None:
                pdf_text = load_pdf(pdf_file)
                if pdf_text:
                    prompt = f"{prompt}\nContexte du PDF:\n{pdf_text}"

            # Grad mode is thread-local. This may be called from a worker
            # thread where the module-level set_grad_enabled(False) does not
            # apply, so autograd is disabled explicitly here.
            with torch.no_grad():
                image = self.pipeline(
                    prompt=prompt,
                    num_inference_steps=4,
                    height=512,
                    width=512,
                    guidance_scale=0.0,
                    # Long PDF context is presumably truncated to this many
                    # text-encoder tokens by the pipeline — confirm.
                    max_sequence_length=128,
                    # Fixed seed 0 makes a given prompt reproducible.
                    generator=torch.Generator(device=self.device).manual_seed(0),
                ).images[0]
            return image

        except Exception as e:
            print(f"Erreur lors de la génération de l'image: {str(e)}")
            return None
        finally:
            # Reclaim memory after every attempt, including failed ones.
            self._release_memory()
|
|
| |
# Lazily-created FluxGenerator shared by every request; the model is loaded on
# the first call to generate() rather than at import time.
generator = None

# Lower-case extensions routed to the PDF and image paths of generate_image.
_PDF_EXTENSIONS = {"pdf"}
_IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}


def _sniff_extension(data):
    """Return an extension guessed from the magic bytes of ``data`` ("" if unknown)."""
    if data.startswith(b"%PDF-"):
        return "pdf"
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if data.startswith(b"\xff\xd8\xff"):
        return "jpg"
    return ""


def _reference_path(reference_file):
    """Normalise a gr.File value into ``(path, is_temporary)``.

    Several shapes are accepted so this works across gr.File ``type``
    settings and Gradio versions:
    - a path string or os.PathLike;
    - a dict with a "path" or "name" key;
    - an object exposing ``.path`` or ``.name`` (e.g. a tempfile wrapper);
    - raw bytes.

    Raw bytes are written to a temporary file whose suffix is sniffed from
    the content. The caller must delete that file; ``is_temporary`` is True in
    that case. ``path`` is None when no path can be derived.
    """
    if isinstance(reference_file, (bytes, bytearray)):
        extension = _sniff_extension(bytes(reference_file))
        suffix = f".{extension}" if extension else ""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(reference_file)
        return tmp.name, True
    if isinstance(reference_file, (str, os.PathLike)):
        return os.fspath(reference_file), False
    if isinstance(reference_file, dict):
        path = reference_file.get("path") or reference_file.get("name")
    else:
        path = getattr(reference_file, "path", None) or getattr(reference_file, "name", None)
    return (os.fspath(path) if path else None), False


def generate(prompt, reference_file):
    """Gradio callback: generate an image from ``prompt`` and an optional file.

    Args:
        prompt: Text prompt from the Textbox.
        reference_file: Value from the gr.File input, or None.

    Returns:
        The generated image, or None if anything fails (errors are printed).
        The reference file is handled by extension:
        - a PDF reference adds its text to the prompt;
        - a PNG/JPEG reference is forwarded as ``reference_image``;
        - any other file is ignored.
    """
    global generator
    temp_path = None
    try:
        if generator is None:
            generator = FluxGenerator()

        file_path = None
        if reference_file is not None:
            file_path, is_temporary = _reference_path(reference_file)
            if is_temporary:
                temp_path = file_path

        if file_path:
            file_type = os.path.splitext(file_path)[1].lower().lstrip(".")
            if file_type in _PDF_EXTENSIONS:
                return generator.generate_image(prompt, pdf_file=file_path)
            if file_type in _IMAGE_EXTENSIONS:
                return generator.generate_image(prompt, reference_image=file_path)

        return generator.generate_image(prompt)

    except Exception as e:
        print(f"Erreur détaillée: {str(e)}")
        return None
    finally:
        # Best-effort removal of the temporary copy made for raw-bytes
        # uploads; a failed delete must not mask the generation result.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
|
|
| |
# Gradio UI: a prompt box plus an optional reference file.
# type="filepath" makes gr.File hand generate() the upload's path on disk.
# generate() picks the PDF or image path from the file extension; raw bytes
# (type="binary") carry no file name to dispatch on.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Décrivez l'image que vous souhaitez générer..."),
        gr.File(
            label="Image ou PDF de référence (optionnel)",
            type="filepath",
            # Restrict uploads to the types generate() knows how to use.
            file_types=[".pdf", ".png", ".jpg", ".jpeg"],
        ),
    ],
    outputs=gr.Image(label="Image générée"),
    title="FLUX (Mode économique)",
    description="Génération d'images optimisée pour systèmes à ressources limitées",
)


if __name__ == "__main__":
    demo.launch()