| |
| |
|
|
| """ |
| DiffSketcher endpoint implementation for Hugging Face. |
| """ |
|
|
| import os |
| import sys |
| import io |
| import base64 |
| import torch |
| import numpy as np |
| from PIL import Image |
| import cairosvg |
| import tempfile |
| import subprocess |
| import shutil |
| from pathlib import Path |
|
|
class DiffSketcherEndpoint:
    """Text-to-SVG endpoint built around the DiffSketcher repository.

    Constructing an instance clones the upstream DiffSketcher repository into
    a private temporary directory, pip-installs its Python dependencies, and
    writes a stub ``diffvg`` package next to the clone. If the DiffSketcher
    modules cannot be imported afterwards, the endpoint keeps working but
    only returns placeholder SVGs.

    Calling the instance with a prompt returns a dict with the SVG text, its
    base64 form, a base64 PNG rendering, and the PNG as a PIL image.
    """

    # Default so that generate_svg() degrades to the placeholder when
    # _initialize_model() has not run (or has not finished).
    model_initialized = False

    # Source of the stub ``diffvg`` package written by _install_dependencies().
    # Every render function returns an all-zero RGBA tensor and every writer
    # is a no-op, so this only makes ``import diffvg`` succeed.
    _MOCK_DIFFVG_SOURCE = (
        "\n"
        "# Mock diffvg module\n"
        "import torch\n"
        "\n"
        "def render(scene, width, height, samples=2, seed=None):\n"
        "    return torch.zeros((height, width, 4), dtype=torch.float32)\n"
        "\n"
        "def render_wrt_shapes(scene, shapes, width, height, samples=2, seed=None):\n"
        "    return torch.zeros((height, width, 4), dtype=torch.float32)\n"
        "\n"
        "def render_wrt_camera(scene, camera, width, height, samples=2, seed=None):\n"
        "    return torch.zeros((height, width, 4), dtype=torch.float32)\n"
        "\n"
        "def imwrite(img, filename, gamma=2.2):\n"
        "    pass\n"
        "\n"
        "def save_svg(scene, filename):\n"
        "    pass\n"
        "\n"
        "def set_use_gpu(use_gpu):\n"
        "    pass\n"
        "\n"
        "def set_print_timing(print_timing):\n"
        "    pass\n"
    )

    def __init__(self, model_dir):
        """Clone DiffSketcher, install its dependencies and probe the model.

        Args:
            model_dir: Stored on the instance as ``self.model_dir``; no other
                method of this class reads it.

        Raises:
            subprocess.CalledProcessError: If ``git clone`` fails.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher endpoint on device: {self.device}")

        # Private scratch area; removed again in __del__.
        self.temp_dir = tempfile.mkdtemp()
        self.temp_model_dir = Path(self.temp_dir) / "DiffSketcher"

        if not self.temp_model_dir.exists():
            print("Cloning DiffSketcher repository...")
            subprocess.run(
                [
                    "git",
                    "clone",
                    "https://github.com/ximinng/DiffSketcher.git",
                    str(self.temp_model_dir),
                ],
                check=True,
            )

        # Make ``import DiffSketcher...`` resolve to the fresh clone.
        sys.path.append(str(self.temp_model_dir.parent))

        self._install_dependencies()
        self._initialize_model()

    def _install_dependencies(self):
        """Install runtime dependencies and create the stub ``diffvg`` package.

        Best effort: any failure is printed and swallowed so construction can
        continue (the endpoint then typically falls back to placeholders).
        """
        try:
            # Use the running interpreter's pip so the packages land in the
            # environment this process actually imports from.
            pip_install = [sys.executable, "-m", "pip", "install"]

            print("Installing DiffSketcher dependencies...")
            subprocess.run(
                pip_install
                + [
                    "svgwrite", "svgpathtools", "cssutils", "numba", "torch",
                    "torchvision", "diffusers", "transformers", "accelerate",
                    "xformers", "omegaconf", "einops", "kornia",
                ],
                check=True,
            )

            print("Installing CLIP...")
            subprocess.run(
                pip_install + ["git+https://github.com/openai/CLIP.git"],
                check=True,
            )

            # diffvg itself is NOT installed; a stub package is written so
            # that imports of ``diffvg`` succeed.
            diffvg_dir = Path(self.temp_dir) / "diffvg"
            diffvg_dir.mkdir(exist_ok=True)
            with open(diffvg_dir / "__init__.py", "w") as f:
                f.write(self._MOCK_DIFFVG_SOURCE)

            # Appended (not prepended), so entries already on sys.path are
            # searched before the stub.
            sys.path.append(str(diffvg_dir.parent))
        except Exception as e:
            print(f"Error installing dependencies: {e}")

    def _initialize_model(self):
        """Record in ``self.model_initialized`` whether DiffSketcher imports.

        No model object is built here; the imports only serve as a check that
        the cloned code and its dependencies are usable.
        """
        try:
            from DiffSketcher.methods.painter.diffsketcher import Painter  # noqa: F401
            from DiffSketcher.methods.diffusers_warp import init_diffusion_pipeline  # noqa: F401

            self.model_initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            print(f"Error initializing DiffSketcher model: {e}")
            self.model_initialized = False

    def generate_svg(self, prompt, num_paths=10, width=512, height=512):
        """Generate SVG markup for a text prompt.

        Args:
            prompt: Text description to sketch.
            num_paths: Number of paths passed to DiffSketcher.
            width: Passed to DiffSketcher as ``image_size`` and used as the
                placeholder width.
            height: Only used for the placeholder SVG; the DiffSketcher
                configuration has no separate height.

        Returns:
            SVG markup as a string. Whenever the model is unavailable or any
            step fails, a placeholder SVG showing the prompt is returned
            instead of raising.
        """
        print(f"Generating SVG for prompt: {prompt}")

        output_dir = None
        try:
            output_dir = Path(tempfile.mkdtemp())
            config_path = output_dir / "config.yaml"
            self._write_config(config_path, prompt, num_paths, width)

            if not self.model_initialized:
                return self._generate_placeholder_svg(prompt, width, height)

            try:
                return self._run_model(prompt, num_paths, width, config_path, output_dir)
            except Exception as e:
                print(f"Error running DiffSketcher model: {e}")
                return self._generate_placeholder_svg(prompt, width, height)
        except Exception as e:
            print(f"Error generating SVG: {e}")
            return self._generate_placeholder_svg(prompt, width, height)
        finally:
            # The per-request scratch directory is no longer needed once the
            # SVG text has been read; remove it so requests do not leak disk.
            if output_dir is not None:
                shutil.rmtree(output_dir, ignore_errors=True)

    def _write_config(self, config_path, prompt, num_paths, width):
        """Write the DiffSketcher YAML configuration to ``config_path``."""
        import json

        # A JSON string literal is a valid YAML double-quoted scalar, so the
        # prompt cannot break the document or inject extra keys.
        quoted_prompt = json.dumps(str(prompt), ensure_ascii=False)
        lines = [
            "",
            "task: diffsketcher",
            "model_id: sd15",
            f"prompt: {quoted_prompt}",
            'negative_prompt: ""',
            f"num_paths: {num_paths}",
            "width: 1.5",
            f"image_size: {width}",
            "num_iter: 500",
            "lr: 1.0",
            "sds:",
            "  warmup: 0",
            "  grad_scale: 1.0",
            "  t_range: [0.02, 0.98]",
            "  guidance_scale: 7.5",
            "",
        ]
        config_path.write_text("\n".join(lines), encoding="utf-8")

    def _run_model(self, prompt, num_paths, width, config_path, output_dir):
        """Run DiffSketcher and return the text of the first SVG it produced.

        Raises:
            FileNotFoundError: If no ``*.svg`` file exists under
                ``output_dir`` after the run.
        """
        from DiffSketcher.run_painterly_render import main
        from DiffSketcher.libs.engine import merge_and_update_config
        from omegaconf import OmegaConf

        args = OmegaConf.create({
            "task": "diffsketcher",
            "config": str(config_path),
            "prompt": prompt,
            "negative_prompt": "",
            "num_paths": num_paths,
            "width": 1.5,
            "image_size": width,
            "num_iter": 500,
            "lr": 1.0,
            "sds": {
                "warmup": 0,
                "grad_scale": 1.0,
                "t_range": [0.02, 0.98],
                "guidance_scale": 7.5,
            },
            "seed": 42,
            "batch_size": 1,
            "render_batch": False,
            "make_video": False,
            "print_timing": False,
            "download": True,
            "force_download": False,
            "resume_download": False,
        })

        args = merge_and_update_config(args)
        main(args, None)

        # NOTE(review): output_dir is never handed to ``main``; whether the
        # results end up below it depends on DiffSketcher itself - confirm.
        svg_files = list(output_dir.glob("**/*.svg"))
        if not svg_files:
            raise FileNotFoundError("No SVG file generated")
        with open(svg_files[0], "r") as f:
            return f.read()

    def _generate_placeholder_svg(self, prompt, width=512, height=512):
        """Return a grey SVG of the given size with the prompt as centred text.

        The prompt is XML-escaped so that characters such as ``<`` and ``&``
        cannot produce malformed SVG or inject markup.
        """
        import html

        safe_prompt = html.escape(str(prompt), quote=False)
        return "\n".join([
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">',
            '  <rect width="100%" height="100%" fill="#f0f0f0"/>',
            '  <text x="50%" y="50%" font-family="Arial" font-size="20" '
            f'text-anchor="middle">{safe_prompt}</text>',
            "</svg>",
        ])

    def svg_to_png(self, svg_content):
        """Convert SVG markup to PNG bytes.

        If the conversion raises, a 512x512 red PNG carrying the error
        message is returned instead of propagating the exception.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")

            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG for ``prompt`` and return it with a PNG rendering.

        Returns:
            dict with keys ``"svg"`` (markup), ``"svg_base64"``,
            ``"png_base64"`` and ``"image"`` (PIL image opened from the PNG).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)
        image = Image.open(io.BytesIO(png_data))

        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }

    def __del__(self):
        """Remove the temporary clone directory, ignoring any errors."""
        # getattr: __init__ may have failed before temp_dir was assigned.
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir:
            # Finalizers must not raise; a missing directory is fine too.
            shutil.rmtree(temp_dir, ignore_errors=True)