| |
| |
| |
| import os |
| import gradio as gr |
| from huggingface_hub import login |
| from diffusers import FluxPipeline |
| import torch |
| from PIL import Image |
| import fitz |
| import sentencepiece |
|
|
def load_pdf(pdf_path):
    """Extract and return the full text of a PDF file.

    Args:
        pdf_path: filesystem path to the PDF, or None.

    Returns:
        The concatenated text of all pages, or None if pdf_path is None
        or the file could not be read.
    """
    if pdf_path is None:
        return None
    try:
        doc = fitz.open(pdf_path)
        try:
            # Concatenate the text of every page in document order.
            return "".join(page.get_text() for page in doc)
        finally:
            # Always release the document handle, even if get_text() fails
            # part-way through (the original leaked it on error).
            doc.close()
    except Exception as e:
        print(f"Erreur lors de la lecture du PDF: {str(e)}")
        return None
|
|
class FluxGenerator:
    """Wrap a FLUX.1-schnell text-to-image pipeline behind a simple API.

    Reads the HuggingFace token from the ``Authentification_HF`` environment
    variable, logs in, and loads the model once at construction time.
    """

    def __init__(self):
        # The FLUX repo is gated on the Hub, so a token is mandatory.
        self.token = os.getenv('Authentification_HF')
        if not self.token:
            raise ValueError("Token d'authentification HuggingFace non trouvé")
        login(self.token)
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load the FLUX pipeline with memory-friendly settings.

        Raises:
            Exception: re-raises any error from model download/initialisation
            after logging it.
        """
        try:
            print("Chargement du modèle FLUX...")
            self.pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                revision="refs/pr/1",
                torch_dtype=torch.bfloat16
            )
            # Offload sub-models to CPU and stream them to the accelerator
            # on demand. NOTE: enable_model_cpu_offload() manages device
            # placement itself; the explicit .to(device) call that used to
            # follow it conflicts with the offload hooks (diffusers errors
            # out on it), so it has been removed — we only log the device.
            self.pipeline.enable_model_cpu_offload()
            self.pipeline.tokenizer.add_prefix_space = False

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Utilisation de l'appareil: {device}")
            print("Modèle FLUX chargé avec succès!")
        except Exception as e:
            print(f"Erreur lors du chargement du modèle: {str(e)}")
            raise

    def generate_image(self, prompt, reference_image=None, pdf_file=None):
        """Generate an image from a text prompt.

        Args:
            prompt: text description of the desired image.
            reference_image: accepted for interface compatibility but
                currently unused by the pipeline call — TODO: wire it in.
            pdf_file: optional path to a PDF whose extracted text is
                appended to the prompt as extra context.

        Returns:
            The generated PIL image, or None if generation failed.
        """
        try:
            if pdf_file is not None:
                pdf_text = load_pdf(pdf_file)
                if pdf_text:
                    prompt = f"{prompt}\nContexte du PDF:\n{pdf_text}"

            image = self.pipeline(
                prompt=prompt,
                num_inference_steps=30,
                guidance_scale=0.0,  # schnell is distilled for guidance-free sampling
                max_sequence_length=256,
                # Fixed seed keeps results reproducible between calls.
                generator=torch.Generator("cpu").manual_seed(0)
            ).images[0]

            return image

        except Exception as e:
            print(f"Erreur lors de la génération de l'image: {str(e)}")
            return None
|
|
| |
# Instantiate the generator once at import time so the model is already
# loaded (or fails fast on a missing token) before the UI starts serving.
generator = FluxGenerator()
|
|
def generate(prompt, reference_file):
    """Gradio callback: dispatch on the reference file's extension.

    Args:
        prompt: text prompt typed by the user.
        reference_file: optional uploaded file — either a filepath string
            or a tempfile-like object exposing ``.name`` (Gradio's two
            historical conventions; the original code only handled the
            latter and crashed on the former).

    Returns:
        The generated image, or None on any error.
    """
    try:
        if reference_file is not None:
            # Accept both a plain path string and an object with .name.
            path = reference_file if isinstance(reference_file, str) else reference_file.name
            ext = os.path.splitext(path)[1].lower().lstrip('.')
            if ext == 'pdf':
                return generator.generate_image(prompt, pdf_file=path)
            if ext in ('png', 'jpg', 'jpeg'):
                return generator.generate_image(prompt, reference_image=path)

        # No usable reference: generate from the prompt alone.
        return generator.generate_image(prompt)

    except Exception as e:
        print(f"Erreur: {str(e)}")
        return None
|
|
| |
# Minimal Gradio UI: one prompt box plus an optional reference file.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Décrivez l'image que vous souhaitez générer..."),
        # type="filepath" hands the callback a path on disk; the previous
        # type="binary" passed raw bytes, which carry no filename and so
        # could never satisfy the extension-based dispatch in generate().
        gr.File(label="Image ou PDF de référence (optionnel)", type="filepath")
    ],
    outputs=gr.Image(label="Image générée"),
    title="Test du modèle FLUX",
    description="Interface simple pour tester la génération d'images avec FLUX"
)
|
|
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|