| """ |
| CompI Phase 2.E: Style Reference/Example Image to AI Art Generation |
| |
| This module implements multimodal AI art generation that combines: |
| - Text prompts with style and mood conditioning |
| - Reference image style transfer and guidance |
| - Image-to-image generation with controllable strength |
| - Support for both local files and web URLs |
| - Advanced style analysis and prompt enhancement |
| |
| Features: |
| - Support for various image formats and web sources |
| - Real-time image analysis and style suggestion |
| - Controllable reference strength for creative flexibility |
| - Comprehensive metadata logging and filename conventions |
| - Batch processing capabilities with multiple variations |
| """ |
|
|
| import os |
| import sys |
| import torch |
| import json |
| from datetime import datetime |
| from typing import Dict, List, Optional, Tuple, Union |
| from pathlib import Path |
| import logging |
|
|
| |
# Add the directory two levels above this file to sys.path so the ``src.*``
# project imports below can be resolved.
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
|
|
| from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline |
| from PIL import Image |
| import numpy as np |
|
|
| from src.utils.image_utils import ImageProcessor, StyleAnalyzer |
| from src.utils.logging_utils import setup_logger |
| from src.utils.file_utils import ensure_directory_exists, generate_filename |
| from src.config import ( |
| STABLE_DIFFUSION_IMG2IMG_MODEL, |
| OUTPUTS_DIR, |
| DEFAULT_IMAGE_SIZE, |
| DEFAULT_INFERENCE_STEPS, |
| DEFAULT_GUIDANCE_SCALE |
| ) |
|
|
| |
# Module-level logger obtained from the project's logging helper; used by every
# method of the generator class below.
logger = setup_logger(__name__)
|
|
class CompIPhase2ERefImageToImage:
    """
    CompI Phase 2.E: Style Reference/Example Image to AI Art Generation System

    Combines text prompts with reference image style guidance for enhanced
    creativity. Two Stable Diffusion pipelines are loaded lazily on first use:
    an img2img pipeline for reference-guided generation and a txt2img pipeline
    used as a text-only fallback.
    """

    # Characters that are path separators or are rejected in file names on
    # common platforms; they are replaced before prompt text becomes a name.
    _UNSAFE_FILENAME_CHARS = frozenset('<>:"/\\|?*')
    # Cap on the prompt-derived part of an output file name so long prompt
    # words cannot push the name past typical 255-byte file name limits.
    _MAX_PROMPT_SLUG_CHARS = 60

    def __init__(
        self,
        model_name: str = STABLE_DIFFUSION_IMG2IMG_MODEL,
        device: Optional[str] = None,
        enable_attention_slicing: bool = True,
        enable_memory_efficient_attention: bool = True
    ):
        """
        Initialize the CompI Phase 2.E system

        No model weights are loaded here; pipelines load on first access.

        Args:
            model_name: Hugging Face model identifier
            device: Device to run on ('cuda', 'cuda:N', 'cpu', or None for
                auto: 'cuda' when available, else 'cpu')
            enable_attention_slicing: Enable attention slicing for memory efficiency
            enable_memory_efficient_attention: Try to enable memory efficient
                (xformers) attention; falls back to default attention with a
                logged warning when unavailable
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Project helpers for loading and analysing reference images.
        self.image_processor = ImageProcessor()
        self.style_analyzer = StyleAnalyzer()

        # Populated lazily by the img2img_pipeline / txt2img_pipeline properties.
        self._img2img_pipeline = None
        self._txt2img_pipeline = None

        self.enable_attention_slicing = enable_attention_slicing
        self.enable_memory_efficient_attention = enable_memory_efficient_attention

        logger.info(f"Initialized CompI Phase 2.E on device: {self.device}")

    # ------------------------------------------------------------------
    # Device / pipeline helpers
    # ------------------------------------------------------------------

    def _is_cuda(self) -> bool:
        """Return True when ``self.device`` names a CUDA device ('cuda', 'cuda:1', ...)."""
        return torch.device(self.device).type == "cuda"

    def _inference_context(self):
        """Return the context manager wrapped around each pipeline call.

        CUDA devices run under ``torch.autocast("cuda")``; every other device
        runs under ``torch.no_grad()``.
        """
        if self._is_cuda():
            return torch.autocast("cuda")
        return torch.no_grad()

    @staticmethod
    def _resolve_seed(seed: Optional[int], index: int) -> int:
        """Return the seed for the ``index``-th image of a batch.

        With an explicit ``seed`` the batch uses ``seed, seed + 1, ...`` so it
        is reproducible. Otherwise a fresh random 32-bit seed is drawn from
        :mod:`secrets`; unlike ``torch.seed()`` this does not reseed the global
        torch RNG as a side effect.
        """
        if seed is not None:
            return seed + index
        import secrets
        return secrets.randbits(32)

    @staticmethod
    def _try_enable_memory_efficient_attention(pipeline) -> None:
        """Best-effort switch of ``pipeline`` to memory-efficient attention.

        diffusers exposes this as ``enable_xformers_memory_efficient_attention``
        (which may raise when xformers or a GPU is unavailable); the legacy
        ``enable_memory_efficient_attention`` name is tried as a fallback.
        Failures are logged and the pipeline keeps its default attention.
        """
        for method_name in (
            "enable_xformers_memory_efficient_attention",
            "enable_memory_efficient_attention",
        ):
            enable = getattr(pipeline, method_name, None)
            if enable is None:
                continue
            try:
                enable()
            except Exception as e:
                logger.warning(f"Memory efficient attention unavailable ({method_name}): {e}")
            return
        logger.info("Pipeline does not support memory efficient attention; skipping")

    def _load_pipeline(self, pipeline_cls, label: str):
        """Load ``pipeline_cls`` for ``self.model_name``, move it to the device and configure it.

        Args:
            pipeline_cls: diffusers pipeline class exposing ``from_pretrained``
            label: Short name used in the log message ('img2img' / 'txt2img')

        Returns:
            The configured pipeline instance.
        """
        logger.info(f"Loading {label} pipeline: {self.model_name}")
        pipeline = pipeline_cls.from_pretrained(
            self.model_name,
            # Half precision only on CUDA; CPU and other backends stay in fp32.
            torch_dtype=torch.float16 if self._is_cuda() else torch.float32,
            safety_checker=None,
            requires_safety_checker=False
        )
        pipeline = pipeline.to(self.device)

        if self.enable_attention_slicing:
            pipeline.enable_attention_slicing()
        if self.enable_memory_efficient_attention:
            self._try_enable_memory_efficient_attention(pipeline)

        return pipeline

    @property
    def img2img_pipeline(self) -> StableDiffusionImg2ImgPipeline:
        """Lazily load and cache the img2img pipeline used with a reference image."""
        if self._img2img_pipeline is None:
            # Cache only after loading/configuration fully succeeded.
            self._img2img_pipeline = self._load_pipeline(StableDiffusionImg2ImgPipeline, "img2img")
        return self._img2img_pipeline

    @property
    def txt2img_pipeline(self) -> StableDiffusionPipeline:
        """Lazily load and cache the txt2img pipeline used as a text-only fallback."""
        if self._txt2img_pipeline is None:
            self._txt2img_pipeline = self._load_pipeline(StableDiffusionPipeline, "txt2img")
        return self._txt2img_pipeline

    # ------------------------------------------------------------------
    # Reference image handling and prompt construction
    # ------------------------------------------------------------------

    def load_reference_image(
        self,
        source: Union[str, Path, Image.Image],
        preprocess: bool = True
    ) -> Optional[Tuple[Image.Image, Dict]]:
        """
        Load and analyze reference image from various sources

        Strings/paths starting with 'http://' or 'https://' are fetched as
        URLs; other strings/paths are read as local files; PIL images are
        converted to RGB.

        Args:
            source: Image source (file path, URL, or PIL Image)
            preprocess: Whether to preprocess the image to DEFAULT_IMAGE_SIZE

        Returns:
            Tuple of (processed_image, analysis_results) or None if failed.
            analysis_results has keys 'source', 'properties',
            'style_suggestions', 'hash' and 'processed_size'.
        """
        try:
            if isinstance(source, Image.Image):
                image = source.convert('RGB')
                source_info = "PIL Image object"
            elif isinstance(source, (str, Path)):
                source_str = str(source)
                if source_str.startswith(('http://', 'https://')):
                    image = self.image_processor.load_image_from_url(source_str)
                    source_info = f"URL: {source_str}"
                else:
                    image = self.image_processor.load_image_from_file(source_str)
                    source_info = f"File: {source_str}"

                # The loaders signal failure by returning None.
                if image is None:
                    return None
            else:
                logger.error(f"Unsupported source type: {type(source)}")
                return None

            if preprocess:
                image = self.image_processor.preprocess_image(image, DEFAULT_IMAGE_SIZE)

            properties = self.image_processor.analyze_image_properties(image)
            style_suggestions = self.style_analyzer.suggest_style_keywords(properties)
            image_hash = self.image_processor.generate_image_hash(image)

            analysis = {
                'source': source_info,
                'properties': properties,
                'style_suggestions': style_suggestions,
                'hash': image_hash,
                'processed_size': image.size
            }

            logger.info(f"Successfully loaded and analyzed reference image: {analysis}")
            return image, analysis

        except Exception as e:
            logger.error(f"Error loading reference image: {e}")
            return None

    def enhance_prompt_with_style(
        self,
        base_prompt: str,
        style: str = "",
        mood: str = "",
        style_suggestions: Optional[List[str]] = None
    ) -> str:
        """
        Enhance prompt with style information from reference image

        Parts are joined with ", " in the order: base prompt, style, mood,
        then at most the first three style suggestions. Empty or
        whitespace-only parts (including None) and case-insensitive duplicates
        are skipped, so e.g. an empty base prompt does not yield a leading ", ".

        Args:
            base_prompt: Base text prompt
            style: Additional style keywords
            mood: Mood/atmosphere keywords
            style_suggestions: Suggested keywords from image analysis

        Returns:
            Enhanced prompt string (``base_prompt`` unchanged if an error occurs)
        """
        try:
            candidates = [base_prompt, style, mood]
            candidates.extend((style_suggestions or [])[:3])

            prompt_parts: List[str] = []
            seen = set()
            for candidate in candidates:
                part = (candidate or "").strip()
                key = part.casefold()
                if not part or key in seen:
                    continue
                seen.add(key)
                prompt_parts.append(part)

            enhanced_prompt = ", ".join(prompt_parts)
            logger.info(f"Enhanced prompt: {enhanced_prompt}")
            return enhanced_prompt

        except Exception as e:
            logger.error(f"Error enhancing prompt: {e}")
            return base_prompt

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def generate_with_reference(
        self,
        prompt: str,
        reference_image: Image.Image,
        style: str = "",
        mood: str = "",
        strength: float = 0.5,
        num_images: int = 1,
        num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        style_suggestions: Optional[List[str]] = None
    ) -> List[Dict]:
        """
        Generate images using reference image guidance (img2img)

        Args:
            prompt: Text prompt
            reference_image: Reference PIL Image used as the img2img starting point
            style: Style keywords
            mood: Mood keywords
            strength: Denoising strength in [0.0, 1.0], passed straight to the
                img2img pipeline. LOWER values keep the output closer to the
                reference; higher values let the prompt override it more.
            num_images: Number of images to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed; image i uses ``seed + i``. None draws a random
                seed per image.
            style_suggestions: Style suggestions from image analysis (at most
                the first three are appended to the prompt)

        Returns:
            List of {'image', 'metadata', 'index'} dicts; an empty list on any
            error (the error is logged).
        """
        try:
            # Reject bad values before the (expensive) pipeline load.
            if not 0.0 <= strength <= 1.0:
                raise ValueError(f"strength must be between 0.0 and 1.0, got {strength}")

            enhanced_prompt = self.enhance_prompt_with_style(
                prompt, style, mood, style_suggestions
            )

            results = []
            for i in range(num_images):
                current_seed = self._resolve_seed(seed, i)
                # One generator per image so each image is reproducible from its recorded seed.
                generator = torch.Generator(device=self.device).manual_seed(current_seed)

                logger.info(f"Generating image {i+1}/{num_images} with reference guidance")

                with self._inference_context():
                    result = self.img2img_pipeline(
                        prompt=enhanced_prompt,
                        image=reference_image,
                        strength=strength,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator
                    )

                generated_image = result.images[0]

                metadata = {
                    'prompt': prompt,
                    'enhanced_prompt': enhanced_prompt,
                    'style': style,
                    'mood': mood,
                    'strength': strength,
                    'num_inference_steps': num_inference_steps,
                    'guidance_scale': guidance_scale,
                    'seed': current_seed,
                    'model': self.model_name,
                    'generation_type': 'img2img_reference',
                    'timestamp': datetime.now().isoformat(),
                    'device': self.device,
                    'reference_size': reference_image.size,
                    'output_size': generated_image.size,
                    'style_suggestions': style_suggestions or []
                }

                results.append({
                    'image': generated_image,
                    'metadata': metadata,
                    'index': i
                })

            logger.info(f"Successfully generated {len(results)} images with reference guidance")
            return results

        except Exception as e:
            logger.error(f"Error generating images with reference: {e}")
            return []

    def generate_without_reference(
        self,
        prompt: str,
        style: str = "",
        mood: str = "",
        num_images: int = 1,
        num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None
    ) -> List[Dict]:
        """
        Generate images without reference (fallback to text-to-image)

        Output size is DEFAULT_IMAGE_SIZE, read as (width, height).

        Args:
            prompt: Text prompt
            style: Style keywords
            mood: Mood keywords
            num_images: Number of images to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed; image i uses ``seed + i``. None draws a random
                seed per image.

        Returns:
            List of {'image', 'metadata', 'index'} dicts; an empty list on any
            error (the error is logged).
        """
        try:
            enhanced_prompt = self.enhance_prompt_with_style(prompt, style, mood)

            results = []
            for i in range(num_images):
                current_seed = self._resolve_seed(seed, i)
                generator = torch.Generator(device=self.device).manual_seed(current_seed)

                logger.info(f"Generating image {i+1}/{num_images} without reference")

                with self._inference_context():
                    result = self.txt2img_pipeline(
                        prompt=enhanced_prompt,
                        height=DEFAULT_IMAGE_SIZE[1],
                        width=DEFAULT_IMAGE_SIZE[0],
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator
                    )

                generated_image = result.images[0]

                metadata = {
                    'prompt': prompt,
                    'enhanced_prompt': enhanced_prompt,
                    'style': style,
                    'mood': mood,
                    'num_inference_steps': num_inference_steps,
                    'guidance_scale': guidance_scale,
                    'seed': current_seed,
                    'model': self.model_name,
                    'generation_type': 'txt2img_fallback',
                    'timestamp': datetime.now().isoformat(),
                    'device': self.device,
                    'output_size': generated_image.size
                }

                results.append({
                    'image': generated_image,
                    'metadata': metadata,
                    'index': i
                })

            logger.info(f"Successfully generated {len(results)} images without reference")
            return results

        except Exception as e:
            logger.error(f"Error generating images without reference: {e}")
            return []

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    @classmethod
    def _safe_filename_part(cls, text: str) -> str:
        """Replace path separators, file-name-invalid characters and control characters with '-'."""
        return "".join(
            "-" if ch in cls._UNSAFE_FILENAME_CHARS or ord(ch) < 32 else ch
            for ch in text
        )

    def save_results(
        self,
        results: List[Dict],
        output_dir: Union[str, Path] = OUTPUTS_DIR,
        reference_info: Optional[Dict] = None
    ) -> List[str]:
        """
        Save generation results with comprehensive metadata

        Each image is written as
        ``{prompt}_{style}_{mood}_{timestamp}_seed{seed}_{REFIMG|NOREFIMG}_v{n}.png``
        next to a ``<same stem>_metadata.json`` file. Prompt/style/mood text is
        sanitized so it cannot introduce path separators or invalid characters.

        Args:
            results: List of generation results ({'image', 'metadata', 'index'})
            output_dir: Output directory (str or Path)
            reference_info: Reference image information, added to each JSON
                file (the caller's metadata dicts are not modified)

        Returns:
            List of saved file paths (image and JSON, interleaved). On error,
            the files written before the failure are returned.
        """
        saved_files: List[str] = []
        try:
            output_dir = Path(output_dir)
            ensure_directory_exists(output_dir)

            for result in results:
                image = result['image']
                metadata = result['metadata']
                index = result['index']

                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                prompt_slug = self._safe_filename_part(
                    "_".join((metadata.get('prompt') or '').lower().split()[:5])
                )[:self._MAX_PROMPT_SLUG_CHARS]
                style_slug = self._safe_filename_part(
                    (metadata.get('style') or '').replace(' ', '')
                )[:10]
                mood_slug = self._safe_filename_part(
                    (metadata.get('mood') or '').replace(' ', '')
                )[:10]

                ref_indicator = (
                    "REFIMG" if metadata.get('generation_type') == 'img2img_reference' else "NOREFIMG"
                )

                filename = f"{prompt_slug}_{style_slug}_{mood_slug}_{timestamp}_seed{metadata['seed']}_{ref_indicator}_v{index+1}.png"
                filepath = output_dir / filename
                image.save(filepath)

                # Enrich a copy so the caller's result metadata is left untouched.
                record = dict(metadata)
                if reference_info:
                    record['reference_info'] = reference_info

                metadata_filepath = output_dir / (filepath.stem + "_metadata.json")
                with open(metadata_filepath, 'w', encoding='utf-8') as f:
                    # default=str stringifies non-JSON values such as tuples' contents or paths.
                    json.dump(record, f, indent=2, default=str)

                saved_files.extend([str(filepath), str(metadata_filepath)])
                logger.info(f"Saved: {filepath}")

            return saved_files

        except Exception as e:
            logger.error(f"Error saving results: {e}")
            # Report what was written before the failure instead of hiding it.
            return saved_files

    # ------------------------------------------------------------------
    # End-to-end entry point
    # ------------------------------------------------------------------

    def generate_batch(
        self,
        prompt: str,
        reference_source: Optional[Union[str, Path, Image.Image]] = None,
        style: str = "",
        mood: str = "",
        strength: float = 0.5,
        num_images: int = 1,
        num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        save_results: bool = True,
        output_dir: Path = OUTPUTS_DIR
    ) -> Dict:
        """
        Complete batch generation pipeline

        Loads the reference (if given), generates with img2img when it loaded
        or falls back to txt2img otherwise, and optionally saves the results.

        Args:
            prompt: Text prompt
            reference_source: Reference image source (file, URL, or PIL Image)
            style: Style keywords
            mood: Mood keywords
            strength: img2img denoising strength in [0.0, 1.0] (only used if
                a reference is provided); lower keeps closer to the reference
            num_images: Number of images to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility
            save_results: Whether to save results to disk (this parameter
                shadows the method name only locally; ``self.save_results``
                still resolves to the method)
            output_dir: Output directory for saved files

        Returns:
            Dictionary with 'results', 'reference_info', 'saved_files' and
            'generation_summary'; on error, 'results', 'reference_info',
            'saved_files' and 'error'.
        """
        try:
            logger.info(f"Starting batch generation: {num_images} images")

            reference_image = None
            reference_info = None
            style_suggestions = []

            if reference_source is not None:
                ref_result = self.load_reference_image(reference_source)
                if ref_result:
                    reference_image, reference_info = ref_result
                    style_suggestions = reference_info.get('style_suggestions', [])
                    logger.info(f"Using reference image with suggestions: {style_suggestions}")
                else:
                    logger.warning("Failed to load reference image, falling back to text-only generation")

            if reference_image is not None:
                results = self.generate_with_reference(
                    prompt=prompt,
                    reference_image=reference_image,
                    style=style,
                    mood=mood,
                    strength=strength,
                    num_images=num_images,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    seed=seed,
                    style_suggestions=style_suggestions
                )
            else:
                results = self.generate_without_reference(
                    prompt=prompt,
                    style=style,
                    mood=mood,
                    num_images=num_images,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    seed=seed
                )

            saved_files = []
            if save_results and results:
                saved_files = self.save_results(results, output_dir, reference_info)

            batch_result = {
                'results': results,
                'reference_info': reference_info,
                'saved_files': saved_files,
                'generation_summary': {
                    'total_images': len(results),
                    'prompt': prompt,
                    'style': style,
                    'mood': mood,
                    'has_reference': reference_image is not None,
                    'style_suggestions': style_suggestions,
                    'timestamp': datetime.now().isoformat()
                }
            }

            logger.info(f"Batch generation complete: {len(results)} images generated")
            return batch_result

        except Exception as e:
            logger.error(f"Error in batch generation: {e}")
            return {
                'results': [],
                'reference_info': None,
                'saved_files': [],
                'error': str(e)
            }
|
|