| |
| |
|
|
| """ |
| Simplified DiffSketcher model for text-to-SVG generation. |
| """ |
|
|
| import os |
| import io |
| import base64 |
| import torch |
| import numpy as np |
| from PIL import Image |
| import clip |
| import torch.nn.functional as F |
| import xml.etree.ElementTree as ET |
| import cairosvg |
|
|
class DiffSketcherModel:
    """Simplified text-to-SVG generator backed by a CLIP text encoder.

    The model encodes the prompt with CLIP and uses the leading components
    of the normalized text embedding to place and color a set of circles.
    """

    def __init__(self, model_dir):
        """Load the CLIP model and select the compute device.

        Args:
            model_dir: Directory searched for a local ``ViT-B-32.pt``
                checkpoint. When the file is absent, CLIP is asked to
                fetch the model by its registered name instead.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.clip_model_path = os.path.join(model_dir, "ViT-B-32.pt")
        if os.path.exists(self.clip_model_path):
            print(f"Loading CLIP model from {self.clip_model_path}")
            self.clip_model, _ = clip.load(self.clip_model_path, device=self.device)
        else:
            print(f"CLIP model not found at {self.clip_model_path}, downloading...")
            # The registered CLIP model name uses a slash ("ViT-B/32"); the
            # dashed spelling only appears in the checkpoint *file* name and
            # is rejected by clip.load() as an unknown model.
            self.clip_model, _ = clip.load("ViT-B/32", device=self.device)

        # Inference only: switch the module to eval mode.
        self.clip_model.eval()

        print(f"DiffSketcher model initialized on device: {self.device}")

    @staticmethod
    def _render_svg(prompt, feature_values, num_paths, width, height):
        """Build the SVG markup for a prompt and its embedding components.

        Args:
            prompt: Caption text; XML-escaped before being embedded.
            feature_values: Floats, one per circle to draw. Must hold at
                most ``num_paths`` entries.
            num_paths: Requested circle count; used to space circles
                vertically.
            width: SVG canvas width.
            height: SVG canvas height.

        Returns:
            The complete SVG document as a string.
        """
        # Local import keeps this block self-contained with respect to the
        # file's top-level import list.
        from xml.sax.saxutils import escape

        # Escape the prompt so characters such as "<" or "&" cannot break
        # the document or inject markup.
        safe_prompt = escape(str(prompt))

        parts = [
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n',
            '<rect width="100%" height="100%" fill="#f0f0f0"/>\n',
            '<text x="50%" y="10%" font-family="Arial" font-size="20" '
            'text-anchor="middle">Generated by DiffSketcher</text>\n',
            '<text x="50%" y="50%" font-family="Arial" font-size="24" '
            f'text-anchor="middle" font-weight="bold">{safe_prompt}</text>\n',
        ]

        for i, feature_val in enumerate(feature_values):
            # Map the component linearly onto the canvas width and hue range.
            x = (feature_val + 1) * width / 2
            # Spread circles over the middle 80% of the canvas height.
            y = ((i / num_paths) * 0.8 + 0.1) * height
            radius = abs(feature_val) * 50 + 10
            hue = (feature_val + 1) * 180
            parts.append(
                f'<circle cx="{x}" cy="{y}" r="{radius}" '
                f'fill="hsl({hue}, 70%, 60%)" opacity="0.7" />'
            )

        parts.append("</svg>")
        return "".join(parts)

    def generate_svg(self, prompt, num_paths=10, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text to encode with CLIP and to print on the canvas.
            num_paths: Maximum number of circles to draw; capped by the
                embedding dimension.
            width: SVG canvas width.
            height: SVG canvas height.

        Returns:
            The SVG document as a string.
        """
        print(f"Generating SVG for prompt: {prompt}")

        with torch.no_grad():
            tokens = clip.tokenize([prompt]).to(self.device)
            text_features = self.clip_model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # One embedding component per circle. Slicing clamps to the embedding
        # size, and max(..., 0) keeps a negative count from selecting a tail.
        feature_values = text_features[0, : max(num_paths, 0)].tolist()

        return self._render_svg(prompt, feature_values, num_paths, width, height)

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG bytes.

        If the conversion fails, the error is printed and a 512x512 red
        placeholder PNG carrying the error message is returned instead, so
        callers always receive decodable PNG data.

        Args:
            svg_content: SVG document as a string.

        Returns:
            PNG-encoded bytes.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:  # deliberate best-effort fallback
            print(f"Error converting SVG to PNG: {e}")

            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG for ``prompt`` and rasterize it to PNG.

        Args:
            prompt: Text prompt passed to :meth:`generate_svg`.

        Returns:
            A dict with keys ``"svg"`` (SVG string), ``"svg_base64"`` and
            ``"png_base64"`` (base64-encoded text) and ``"image"`` (the PNG
            opened as a PIL image).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        image = Image.open(io.BytesIO(png_data))

        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }