| |
| |
|
|
| """ |
| Simplified DiffSketcher implementation for Hugging Face Inference API. |
| This version doesn't rely on cloning the repository at runtime. |
| """ |
|
|
| import os |
| import io |
| import base64 |
| import torch |
| import numpy as np |
| from PIL import Image |
| import cairosvg |
| import random |
| from pathlib import Path |
|
|
class SimplifiedDiffSketcher:
    """Template-based stand-in for DiffSketcher that needs no repository clone.

    The prompt is embedded with OpenAI CLIP when the ``clip`` package and its
    weights can be loaded; otherwise a random vector is used. A few components
    of that vector select the colour, size and body style of a hand-written
    car template, which is emitted as SVG and can be rasterised to PNG.
    """

    # OpenAI CLIP model names use a slash ("ViT-B/32"); the hyphenated
    # "ViT-B-32" is open_clip's spelling and is rejected by ``clip.load``.
    CLIP_MODEL_NAME = "ViT-B/32"
    # Length of the random fallback feature vector (ViT-B/32 text embeddings
    # are also 512-dimensional).
    FEATURE_DIM = 512
    # Bounds for the car scale factor. Random fallback features are unbounded
    # (standard normal), so without clamping the scale could go to zero or
    # negative and produce invalid negative SVG widths/heights.
    MIN_CAR_SIZE = 0.1
    MAX_CAR_SIZE = 1.4

    def __init__(self, model_dir):
        """Initialize the simplified DiffSketcher model.

        Args:
            model_dir: Stored as ``self.model_dir``; not used by this class's
                methods.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing simplified DiffSketcher on device: {self.device}")

        # CLIP is optional: any failure (package missing, weights unavailable)
        # leaves the sketcher usable, driven by random features instead.
        try:
            import clip
            self.clip_model, _ = clip.load(self.CLIP_MODEL_NAME, device=self.device)
            self.clip_available = True
            print("CLIP model loaded successfully")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            self.clip_available = False

    def generate_svg(self, prompt, num_paths=20, width=512, height=512):
        """Generate an SVG document from a text prompt.

        Args:
            prompt: Text prompt; also rendered (escaped) as a caption.
            num_paths: Forwarded to the template renderer, which ignores it.
            width: Canvas width in SVG user units.
            height: Canvas height in SVG user units.

        Returns:
            The SVG markup as a ``str``.
        """
        print(f"Generating SVG for prompt: {prompt}")
        text_features = self._encode_prompt(prompt)
        return self._generate_car_svg(prompt, text_features, num_paths, width, height)

    def _encode_prompt(self, prompt):
        """Return a 1-D feature vector for *prompt*.

        Uses the unit-normalised CLIP text embedding when CLIP is available
        and encoding succeeds; otherwise a standard-normal random vector of
        length ``FEATURE_DIM``.
        """
        if self.clip_available:
            try:
                import clip
                with torch.no_grad():
                    # truncate=True: CLIP's context is 77 tokens and tokenize
                    # raises on longer prompts, which would otherwise silently
                    # fall through to random features.
                    tokens = clip.tokenize([prompt], truncate=True).to(self.device)
                    encoded = self.clip_model.encode_text(tokens)
                # encode_text yields fp16 on CUDA; compute in float32.
                features = encoded.float().cpu().numpy()[0]
                norm = np.linalg.norm(features)
                return features / norm if norm > 0 else features
            except Exception as e:
                print(f"Error encoding prompt with CLIP: {e}")
        return np.random.randn(self.FEATURE_DIM)

    @staticmethod
    def _circle_wheels(car_x, car_width, wheel_y, wheel_radius):
        """Return SVG circles for two round wheels: black tyres, then grey hubs."""
        elements = []
        for fill, radius in (("black", wheel_radius), ("#444", wheel_radius * 0.6)):
            for frac in (0.2, 0.8):
                elements.append(
                    f'<circle cx="{car_x + car_width * frac}" cy="{wheel_y}" '
                    f'r="{radius}" fill="{fill}" />'
                )
        return elements

    def _generate_car_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Render a car sketch whose appearance is driven by ``features``.

        Args:
            prompt: Caption text; XML-escaped before being embedded.
            features: Indexable numeric sequence with at least 3 entries.
                features[0] sets the hue, features[1] the scale (clamped to
                [MIN_CAR_SIZE, MAX_CAR_SIZE]), features[2] the body style
                (0, 1 or 2), and the first (up to) 10 entries place
                decorative dots.
            num_paths: Accepted for interface compatibility; not used.
            width: Canvas width in SVG user units.
            height: Canvas height in SVG user units.

        Returns:
            SVG markup as a ``str``.
        """
        from xml.sax.saxutils import escape

        parts = [
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">',
            '<rect width="100%" height="100%" fill="#f8f8f8"/>',
        ]

        # Map feature components to visual parameters.
        car_color_hue = int((features[0] + 1) * 180) % 360
        car_size = 0.6 + 0.2 * features[1]
        car_size = min(max(car_size, self.MIN_CAR_SIZE), self.MAX_CAR_SIZE)
        car_style = int(abs(features[2] * 3)) % 3

        car_width = int(width * 0.7 * car_size)
        car_height = int(height * 0.3 * car_size)
        car_x = (width - car_width) // 2
        car_y = height // 2

        if car_style == 0:
            # Low rounded-rectangle body.
            parts.append(
                f'<rect x="{car_x}" y="{car_y}" width="{car_width}" height="{car_height}" '
                f'rx="20" ry="20" fill="hsl({car_color_hue}, 80%, 50%)" '
                f'stroke="black" stroke-width="2" />'
            )
            windshield_width = car_width * 0.7
            windshield_height = car_height * 0.5
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - windshield_height * 0.3
            parts.append(
                f'<rect x="{windshield_x}" y="{windshield_y}" width="{windshield_width}" '
                f'height="{windshield_height}" rx="10" ry="10" fill="#a8d8ff" '
                f'stroke="black" stroke-width="1" />'
            )
            parts.extend(self._circle_wheels(
                car_x, car_width, car_y + car_height * 0.8, car_height * 0.4))

        elif car_style == 1:
            # Taller rounded-rectangle body.
            parts.append(
                f'<rect x="{car_x}" y="{car_y - car_height * 0.3}" width="{car_width}" '
                f'height="{car_height * 1.3}" rx="15" ry="15" '
                f'fill="hsl({car_color_hue}, 80%, 50%)" stroke="black" stroke-width="2" />'
            )
            windshield_width = car_width * 0.6
            windshield_height = car_height * 0.6
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - car_height * 0.2
            parts.append(
                f'<rect x="{windshield_x}" y="{windshield_y}" width="{windshield_width}" '
                f'height="{windshield_height}" rx="8" ry="8" fill="#a8d8ff" '
                f'stroke="black" stroke-width="1" />'
            )
            parts.extend(self._circle_wheels(
                car_x, car_width, car_y + car_height * 0.7, car_height * 0.45))

        else:
            # Curved body drawn as a single closed Bezier path.
            parts.append(
                f'<path d="M {car_x} {car_y + car_height * 0.5} '
                f'C {car_x + car_width * 0.1} {car_y - car_height * 0.2}, '
                f'{car_x + car_width * 0.3} {car_y - car_height * 0.3}, '
                f'{car_x + car_width * 0.5} {car_y - car_height * 0.2} '
                f'S {car_x + car_width * 0.9} {car_y}, '
                f'{car_x + car_width} {car_y + car_height * 0.3} '
                f'L {car_x + car_width} {car_y + car_height * 0.7} '
                f'C {car_x + car_width * 0.9} {car_y + car_height}, '
                f'{car_x + car_width * 0.1} {car_y + car_height}, '
                f'{car_x} {car_y + car_height * 0.7} Z" '
                f'fill="hsl({car_color_hue}, 90%, 45%)" stroke="black" stroke-width="2" />'
            )
            windshield_width = car_width * 0.4
            windshield_x = car_x + car_width * 0.3
            windshield_y = car_y - car_height * 0.1
            parts.append(
                f'<path d="M {windshield_x} {windshield_y} '
                f'C {windshield_x + windshield_width * 0.1} {windshield_y - car_height * 0.15}, '
                f'{windshield_x + windshield_width * 0.9} {windshield_y - car_height * 0.15}, '
                f'{windshield_x + windshield_width} {windshield_y} Z" '
                f'fill="#a8d8ff" stroke="black" stroke-width="1" />'
            )
            wheel_radius = car_height * 0.35
            wheel_y = car_y + car_height * 0.7
            for rx, ry, fill in ((wheel_radius * 1.2, wheel_radius, "black"),
                                 (wheel_radius * 0.7, wheel_radius * 0.6, "#444")):
                for frac in (0.2, 0.8):
                    parts.append(
                        f'<ellipse cx="{car_x + car_width * frac}" cy="{wheel_y}" '
                        f'rx="{rx}" ry="{ry}" fill="{fill}" />'
                    )

        # Headlights (all styles).
        headlight_radius = car_width * 0.05
        headlight_y = car_y + car_height * 0.3
        for frac in (0.1, 0.9):
            parts.append(
                f'<circle cx="{car_x + car_width * frac}" cy="{headlight_y}" '
                f'r="{headlight_radius}" fill="yellow" stroke="black" stroke-width="1" />'
            )

        # Semi-transparent dots driven by the first (up to) 10 feature values.
        for i, feature_val in enumerate(features[:10]):
            x = car_x + car_width * ((i / 10) * 0.8 + 0.1)
            y = car_y + car_height * ((feature_val + 1) / 4)
            size = car_width * 0.03 * abs(feature_val)
            parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="rgba(0,0,0,0.2)" />')

        # The prompt is user input: escape &, < and > so the document stays
        # well-formed and the prompt cannot inject SVG markup.
        parts.append(
            f'<text x="{width / 2}" y="{height - 20}" font-family="Arial" font-size="12" '
            f'text-anchor="middle">{escape(str(prompt))}</text>'
        )
        parts.append("</svg>")
        return "\n".join(parts)

    def svg_to_png(self, svg_content):
        """Rasterise SVG markup to PNG bytes.

        On any conversion failure this does not raise: it returns a 512x512
        red PNG with the error message drawn on it, so callers always receive
        PNG bytes.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG for *prompt* and rasterise it.

        Returns:
            dict with keys "svg" (markup str), "svg_base64" and "png_base64"
            (base64-encoded str), and "image" (the PNG opened with PIL).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)
        image = Image.open(io.BytesIO(png_data))
        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }