Major update: Implement real DiffSketchEdit algorithm with word replacement, refinement, and attention reweighting
Browse files- handler.py +422 -481
handler.py
CHANGED
|
@@ -1,81 +1,74 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
import torch
|
| 4 |
-
import
|
| 5 |
-
import json
|
| 6 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import svgwrite
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import random
|
| 9 |
import math
|
| 10 |
-
|
| 11 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
| 12 |
-
from typing import List, Dict, Any, Tuple
|
| 13 |
-
import io
|
| 14 |
-
from PIL import Image
|
| 15 |
|
| 16 |
-
class
|
| 17 |
-
def __init__(self
|
| 18 |
-
""
|
| 19 |
-
self.
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
try:
|
| 38 |
-
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
| 39 |
-
self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
| 40 |
-
self.text_encoder = self.text_encoder.to(self.device)
|
| 41 |
-
print("Text encoder loaded successfully")
|
| 42 |
-
except Exception as e:
|
| 43 |
-
print(f"Error loading text encoder: {e}")
|
| 44 |
-
self.tokenizer = None
|
| 45 |
-
self.text_encoder = None
|
| 46 |
|
| 47 |
-
def __call__(self,
|
| 48 |
-
"""
|
|
|
|
|
|
|
| 49 |
try:
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
prompts = inputs.get("prompts", [])
|
| 57 |
-
if not prompts and "prompt" in inputs:
|
| 58 |
-
prompts = [inputs["prompt"]]
|
| 59 |
-
edit_type = inputs.get("edit_type", "refine")
|
| 60 |
-
input_svg = inputs.get("input_svg", None)
|
| 61 |
else:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
# Extract parameters
|
| 71 |
width = parameters.get("width", 224)
|
| 72 |
height = parameters.get("height", 224)
|
| 73 |
-
seed = parameters.get("seed",
|
|
|
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
|
| 80 |
print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")
|
| 81 |
|
|
@@ -96,6 +89,7 @@ class EndpointHandler:
|
|
| 96 |
pil_image = self.svg_to_pil_image(svg_content, width, height)
|
| 97 |
|
| 98 |
# Store metadata
|
|
|
|
| 99 |
for key, value in metadata.items():
|
| 100 |
if isinstance(value, (dict, list)):
|
| 101 |
pil_image.info[key] = json.dumps(value)
|
|
@@ -118,16 +112,11 @@ class EndpointHandler:
|
|
| 118 |
try:
|
| 119 |
print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
|
| 120 |
|
| 121 |
-
# Analyze
|
| 122 |
-
|
| 123 |
-
target_words = set(target_prompt.lower().split())
|
| 124 |
-
|
| 125 |
-
added_words = target_words - source_words
|
| 126 |
-
removed_words = source_words - target_words
|
| 127 |
-
|
| 128 |
print(f"Added words: {added_words}, Removed words: {removed_words}")
|
| 129 |
|
| 130 |
-
# Generate
|
| 131 |
if input_svg:
|
| 132 |
base_svg = input_svg
|
| 133 |
else:
|
|
@@ -184,8 +173,9 @@ class EndpointHandler:
|
|
| 184 |
try:
|
| 185 |
print(f"Attention reweighting for: '{prompt}'")
|
| 186 |
|
| 187 |
-
# Parse attention weights from prompt (e.g., "(cat:1.5)" or "[
|
| 188 |
weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
|
|
|
|
| 189 |
|
| 190 |
# Generate or use base SVG
|
| 191 |
if input_svg:
|
|
@@ -236,518 +226,469 @@ class EndpointHandler:
|
|
| 236 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 237 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 238 |
|
| 239 |
-
#
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
self._add_animal_elements(dwg, width, height, prompt_lower)
|
| 254 |
else:
|
| 255 |
-
self.
|
| 256 |
|
| 257 |
return dwg.tostring()
|
| 258 |
|
| 259 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
"""Apply word replacement transformations to SVG"""
|
| 261 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 263 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 264 |
|
| 265 |
-
#
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
# Apply transformations based on target prompt
|
| 279 |
-
target_lower = target_prompt.lower()
|
| 280 |
-
if any(word in target_lower for word in ['house', 'building']):
|
| 281 |
-
self._add_house_elements(dwg, width, height)
|
| 282 |
-
elif any(word in target_lower for word in ['tree', 'forest']):
|
| 283 |
-
self._add_tree_elements(dwg, width, height)
|
| 284 |
-
elif any(word in target_lower for word in ['car', 'vehicle']):
|
| 285 |
-
self._add_car_elements(dwg, width, height)
|
| 286 |
|
| 287 |
return dwg.tostring()
|
| 288 |
|
| 289 |
def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
|
| 290 |
"""Apply refinement to existing SVG"""
|
|
|
|
|
|
|
|
|
|
| 291 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 292 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
if 'detailed' in prompt_lower or 'complex' in prompt_lower:
|
| 298 |
-
self._add_detailed_elements(dwg, width, height, prompt)
|
| 299 |
-
elif 'simple' in prompt_lower or 'minimal' in prompt_lower:
|
| 300 |
-
self._add_simple_elements(dwg, width, height, prompt)
|
| 301 |
else:
|
| 302 |
-
|
| 303 |
-
self._add_standard_elements(dwg, width, height, prompt)
|
| 304 |
|
| 305 |
return dwg.tostring()
|
| 306 |
|
| 307 |
def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
|
| 308 |
-
"""Apply attention reweighting to SVG
|
| 309 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 310 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 311 |
|
| 312 |
-
# Apply
|
| 313 |
for word, weight in attention_weights.items():
|
| 314 |
if weight > 1.0:
|
| 315 |
# Emphasize this element
|
| 316 |
-
self.
|
| 317 |
elif weight < 1.0:
|
| 318 |
# De-emphasize this element
|
| 319 |
-
self.
|
| 320 |
|
| 321 |
-
# Add base
|
| 322 |
-
self.
|
| 323 |
|
| 324 |
return dwg.tostring()
|
| 325 |
|
| 326 |
-
def
|
| 327 |
-
"""
|
| 328 |
-
|
| 329 |
|
| 330 |
-
#
|
| 331 |
-
|
| 332 |
-
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
-
def
|
| 349 |
-
"""Add
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
dwg.add(dwg.rect(
|
| 357 |
-
insert=(
|
| 358 |
-
size=(
|
| 359 |
-
fill='
|
| 360 |
stroke='black',
|
| 361 |
stroke_width=2
|
| 362 |
))
|
| 363 |
|
| 364 |
# Roof
|
| 365 |
-
roof_points = [
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
# Door
|
| 373 |
-
door_width =
|
| 374 |
-
door_height =
|
| 375 |
-
door_x =
|
| 376 |
-
door_y =
|
| 377 |
-
|
| 378 |
dwg.add(dwg.rect(
|
| 379 |
insert=(door_x, door_y),
|
| 380 |
size=(door_width, door_height),
|
| 381 |
-
fill='
|
| 382 |
stroke='black',
|
| 383 |
stroke_width=2
|
| 384 |
))
|
| 385 |
|
| 386 |
-
def
|
| 387 |
-
"""Add
|
| 388 |
-
|
| 389 |
-
center_y = height / 2
|
| 390 |
|
| 391 |
# Trunk
|
| 392 |
-
trunk_width =
|
| 393 |
-
trunk_height = height
|
|
|
|
|
|
|
|
|
|
| 394 |
dwg.add(dwg.rect(
|
| 395 |
-
insert=(
|
| 396 |
size=(trunk_width, trunk_height),
|
| 397 |
-
fill='
|
| 398 |
stroke='black',
|
| 399 |
-
stroke_width=
|
| 400 |
))
|
| 401 |
|
| 402 |
-
# Crown
|
| 403 |
-
crown_radius =
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
|
|
|
|
|
|
| 411 |
|
| 412 |
-
def
|
| 413 |
-
"""Add
|
| 414 |
-
|
| 415 |
-
car_height = height * 0.3
|
| 416 |
-
car_x = (width - car_width) / 2
|
| 417 |
-
car_y = (height - car_height) / 2
|
| 418 |
|
| 419 |
# Car body
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
dwg.add(dwg.rect(
|
| 421 |
insert=(car_x, car_y),
|
| 422 |
size=(car_width, car_height),
|
| 423 |
-
fill='
|
| 424 |
stroke='black',
|
| 425 |
stroke_width=2,
|
| 426 |
rx=5
|
| 427 |
))
|
| 428 |
|
| 429 |
-
#
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
center=(car_x + car_width * 0.2, wheel_y),
|
| 435 |
-
r=wheel_radius,
|
| 436 |
-
fill='none',
|
| 437 |
-
stroke='black',
|
| 438 |
-
stroke_width=2
|
| 439 |
-
))
|
| 440 |
-
dwg.add(dwg.circle(
|
| 441 |
-
center=(car_x + car_width * 0.8, wheel_y),
|
| 442 |
-
r=wheel_radius,
|
| 443 |
-
fill='none',
|
| 444 |
-
stroke='black',
|
| 445 |
-
stroke_width=2
|
| 446 |
-
))
|
| 447 |
-
|
| 448 |
-
def _add_face_elements(self, dwg, width, height):
|
| 449 |
-
"""Add face elements to SVG"""
|
| 450 |
-
center_x = width / 2
|
| 451 |
-
center_y = height / 2
|
| 452 |
-
face_radius = min(width, height) * 0.3
|
| 453 |
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
fill='none',
|
| 459 |
stroke='black',
|
| 460 |
-
stroke_width=
|
| 461 |
-
))
|
| 462 |
-
|
| 463 |
-
# Eyes
|
| 464 |
-
eye_offset = face_radius * 0.3
|
| 465 |
-
eye_radius = face_radius * 0.1
|
| 466 |
-
|
| 467 |
-
dwg.add(dwg.circle(
|
| 468 |
-
center=(center_x - eye_offset, center_y - eye_offset),
|
| 469 |
-
r=eye_radius,
|
| 470 |
-
fill='black'
|
| 471 |
-
))
|
| 472 |
-
dwg.add(dwg.circle(
|
| 473 |
-
center=(center_x + eye_offset, center_y - eye_offset),
|
| 474 |
-
r=eye_radius,
|
| 475 |
-
fill='black'
|
| 476 |
))
|
| 477 |
|
| 478 |
-
#
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
stroke='black',
|
| 484 |
-
stroke_width=2
|
| 485 |
-
))
|
| 486 |
|
| 487 |
-
def
|
| 488 |
-
"""Add
|
| 489 |
-
|
| 490 |
-
center_y = height / 2
|
| 491 |
-
|
| 492 |
-
# Stem
|
| 493 |
-
dwg.add(dwg.line(
|
| 494 |
-
start=(center_x, center_y + 20),
|
| 495 |
-
end=(center_x, height - 20),
|
| 496 |
-
stroke='green',
|
| 497 |
-
stroke_width=4
|
| 498 |
-
))
|
| 499 |
-
|
| 500 |
-
# Petals
|
| 501 |
-
petal_radius = 15
|
| 502 |
-
for angle in range(0, 360, 45):
|
| 503 |
-
x = center_x + 25 * math.cos(math.radians(angle))
|
| 504 |
-
y = center_y + 25 * math.sin(math.radians(angle))
|
| 505 |
-
dwg.add(dwg.circle(
|
| 506 |
-
center=(x, y),
|
| 507 |
-
r=petal_radius,
|
| 508 |
-
fill='none',
|
| 509 |
-
stroke='red',
|
| 510 |
-
stroke_width=2
|
| 511 |
-
))
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
r=8,
|
| 517 |
-
fill='yellow',
|
| 518 |
-
stroke='orange',
|
| 519 |
-
stroke_width=2
|
| 520 |
-
))
|
| 521 |
-
|
| 522 |
-
def _add_animal_elements(self, dwg, width, height, animal_type):
|
| 523 |
-
"""Add animal elements to SVG"""
|
| 524 |
-
center_x = width / 2
|
| 525 |
-
center_y = height / 2
|
| 526 |
-
|
| 527 |
-
if 'cat' in animal_type:
|
| 528 |
-
# Cat body
|
| 529 |
-
dwg.add(dwg.ellipse(
|
| 530 |
-
center=(center_x, center_y + 20),
|
| 531 |
-
r=(30, 20),
|
| 532 |
-
fill='none',
|
| 533 |
-
stroke='black',
|
| 534 |
-
stroke_width=2
|
| 535 |
-
))
|
| 536 |
-
|
| 537 |
-
# Cat head
|
| 538 |
-
dwg.add(dwg.circle(
|
| 539 |
-
center=(center_x, center_y - 20),
|
| 540 |
-
r=25,
|
| 541 |
-
fill='none',
|
| 542 |
-
stroke='black',
|
| 543 |
-
stroke_width=2
|
| 544 |
-
))
|
| 545 |
-
|
| 546 |
-
# Cat ears
|
| 547 |
-
ear_points1 = [(center_x - 15, center_y - 35), (center_x - 5, center_y - 50), (center_x + 5, center_y - 35)]
|
| 548 |
-
ear_points2 = [(center_x - 5, center_y - 35), (center_x + 5, center_y - 50), (center_x + 15, center_y - 35)]
|
| 549 |
-
dwg.add(dwg.polygon(ear_points1, fill='none', stroke='black', stroke_width=2))
|
| 550 |
-
dwg.add(dwg.polygon(ear_points2, fill='none', stroke='black', stroke_width=2))
|
| 551 |
-
|
| 552 |
-
elif 'dog' in animal_type:
|
| 553 |
-
# Dog body
|
| 554 |
-
dwg.add(dwg.ellipse(
|
| 555 |
-
center=(center_x, center_y + 10),
|
| 556 |
-
r=(40, 25),
|
| 557 |
-
fill='none',
|
| 558 |
-
stroke='black',
|
| 559 |
-
stroke_width=2
|
| 560 |
-
))
|
| 561 |
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
fill=
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
|
| 571 |
-
def
|
| 572 |
-
"""Add
|
| 573 |
color_map = {
|
| 574 |
'red': '#FF0000',
|
| 575 |
'blue': '#0000FF',
|
| 576 |
'green': '#00FF00',
|
| 577 |
'yellow': '#FFFF00',
|
| 578 |
-
'purple': '#800080'
|
|
|
|
| 579 |
}
|
| 580 |
|
| 581 |
-
|
| 582 |
|
| 583 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
dwg.add(dwg.circle(
|
| 585 |
-
center=(
|
| 586 |
-
r=
|
| 587 |
-
fill=
|
| 588 |
stroke='black',
|
| 589 |
-
stroke_width=
|
| 590 |
))
|
| 591 |
|
| 592 |
-
def
|
| 593 |
-
"""Add
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
))
|
| 604 |
-
elif size_type == 'small':
|
| 605 |
-
# Add smaller elements
|
| 606 |
-
dwg.add(dwg.rect(
|
| 607 |
-
insert=(width*0.3, height*0.3),
|
| 608 |
-
size=(width*0.4, height*0.4),
|
| 609 |
-
fill='none',
|
| 610 |
-
stroke='gray',
|
| 611 |
-
stroke_width=1,
|
| 612 |
-
stroke_dasharray='2,2'
|
| 613 |
))
|
| 614 |
|
| 615 |
-
def
|
| 616 |
-
"""Add
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
|
| 624 |
-
def
|
| 625 |
-
"""Add
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
shape_type = random.choice(['circle', 'rect', 'polygon'])
|
| 633 |
-
|
| 634 |
-
if shape_type == 'circle':
|
| 635 |
-
dwg.add(dwg.circle(
|
| 636 |
-
center=(x, y),
|
| 637 |
-
r=size,
|
| 638 |
-
fill='none',
|
| 639 |
-
stroke='black',
|
| 640 |
-
stroke_width=1,
|
| 641 |
-
opacity=0.7
|
| 642 |
-
))
|
| 643 |
-
elif shape_type == 'rect':
|
| 644 |
-
dwg.add(dwg.rect(
|
| 645 |
-
insert=(x-size, y-size),
|
| 646 |
-
size=(size*2, size*2),
|
| 647 |
-
fill='none',
|
| 648 |
-
stroke='black',
|
| 649 |
-
stroke_width=1,
|
| 650 |
-
opacity=0.7
|
| 651 |
-
))
|
| 652 |
-
|
| 653 |
-
def _add_simple_elements(self, dwg, width, height, prompt):
|
| 654 |
-
"""Add simple elements for minimal prompts"""
|
| 655 |
-
# Add just a few basic shapes
|
| 656 |
-
center_x = width / 2
|
| 657 |
-
center_y = height / 2
|
| 658 |
|
| 659 |
dwg.add(dwg.circle(
|
| 660 |
-
center=(center_x, center_y),
|
| 661 |
-
r=
|
| 662 |
-
fill='
|
|
|
|
| 663 |
stroke='black',
|
| 664 |
stroke_width=2
|
| 665 |
))
|
| 666 |
|
| 667 |
-
def
|
| 668 |
-
"""Add
|
| 669 |
-
|
| 670 |
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
self._add_tree_elements(dwg, width, height)
|
| 675 |
-
elif any(word in prompt_lower for word in ['car', 'vehicle']):
|
| 676 |
-
self._add_car_elements(dwg, width, height)
|
| 677 |
-
else:
|
| 678 |
-
self._add_abstract_elements(dwg, width, height, prompt)
|
| 679 |
-
|
| 680 |
-
def _add_abstract_elements(self, dwg, width, height, prompt):
|
| 681 |
-
"""Add abstract elements based on prompt"""
|
| 682 |
-
prompt_hash = hash(prompt) % 100
|
| 683 |
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
fill='none',
|
| 693 |
-
stroke='black',
|
| 694 |
-
stroke_width=2,
|
| 695 |
-
opacity=0.8
|
| 696 |
-
))
|
| 697 |
-
|
| 698 |
-
def _emphasize_element(self, dwg, word, weight, width, height):
|
| 699 |
-
"""Emphasize an element based on attention weight"""
|
| 700 |
-
# Make elements larger and more prominent
|
| 701 |
-
scale_factor = weight
|
| 702 |
-
stroke_width = int(2 * scale_factor)
|
| 703 |
-
|
| 704 |
-
if word in ['house', 'building']:
|
| 705 |
-
# Emphasized house
|
| 706 |
-
house_size = min(width, height) * 0.4 * scale_factor
|
| 707 |
-
house_x = (width - house_size) / 2
|
| 708 |
-
house_y = (height - house_size) / 2
|
| 709 |
-
|
| 710 |
-
dwg.add(dwg.rect(
|
| 711 |
-
insert=(house_x, house_y),
|
| 712 |
-
size=(house_size, house_size * 0.8),
|
| 713 |
-
fill='none',
|
| 714 |
-
stroke='red',
|
| 715 |
-
stroke_width=stroke_width
|
| 716 |
-
))
|
| 717 |
|
| 718 |
-
def
|
| 719 |
-
"""
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
)
|
| 734 |
|
| 735 |
-
def
|
| 736 |
-
"""
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
"error": error
|
| 744 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 745 |
|
| 746 |
-
def svg_to_pil_image(self, svg_content, width, height):
|
| 747 |
"""Convert SVG content to PIL Image"""
|
| 748 |
try:
|
| 749 |
import cairosvg
|
| 750 |
-
import io
|
| 751 |
|
| 752 |
# Convert SVG to PNG bytes
|
| 753 |
png_bytes = cairosvg.svg2png(
|
|
@@ -778,10 +719,10 @@ class EndpointHandler:
|
|
| 778 |
|
| 779 |
# Simple centered text
|
| 780 |
dwg.add(dwg.text(
|
| 781 |
-
f"DiffSketchEdit\n{prompt[:
|
| 782 |
insert=(width/2, height/2),
|
| 783 |
text_anchor="middle",
|
| 784 |
-
font_size="
|
| 785 |
fill="black"
|
| 786 |
))
|
| 787 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import torch.nn.functional as F
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
+
import json
|
| 5 |
+
import base64
|
| 6 |
+
import io
|
| 7 |
+
from PIL import Image
|
| 8 |
import svgwrite
|
| 9 |
+
from typing import Dict, Any, List, Optional, Union
|
| 10 |
+
import diffusers
|
| 11 |
+
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
| 12 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
import random
|
| 15 |
import math
|
| 16 |
+
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
class DiffSketchEditHandler:
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
self.model_id = "runwayml/stable-diffusion-v1-5"
|
| 22 |
+
|
| 23 |
+
# Initialize the diffusion pipeline
|
| 24 |
+
self.pipe = StableDiffusionPipeline.from_pretrained(
|
| 25 |
+
self.model_id,
|
| 26 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 27 |
+
safety_checker=None,
|
| 28 |
+
requires_safety_checker=False
|
| 29 |
+
).to(self.device)
|
| 30 |
+
|
| 31 |
+
# Use DDIM scheduler for better control
|
| 32 |
+
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
|
| 33 |
+
|
| 34 |
+
# CLIP model for guidance
|
| 35 |
+
self.clip_model = self.pipe.text_encoder
|
| 36 |
+
self.clip_tokenizer = self.pipe.tokenizer
|
| 37 |
+
|
| 38 |
+
print("DiffSketchEdit handler initialized successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
|
| 41 |
+
"""
|
| 42 |
+
Perform sketch editing using DiffSketchEdit approach
|
| 43 |
+
"""
|
| 44 |
try:
|
| 45 |
+
# Parse inputs
|
| 46 |
+
if isinstance(inputs, str):
|
| 47 |
+
# Simple prompt - treat as generation
|
| 48 |
+
prompts = [inputs]
|
| 49 |
+
edit_type = "generate"
|
| 50 |
+
parameters = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
else:
|
| 52 |
+
input_data = inputs.get("inputs", inputs)
|
| 53 |
+
if isinstance(input_data, str):
|
| 54 |
+
prompts = [input_data]
|
| 55 |
+
edit_type = "generate"
|
| 56 |
+
else:
|
| 57 |
+
prompts = input_data.get("prompts", [input_data.get("prompt", "a simple sketch")])
|
| 58 |
+
edit_type = input_data.get("edit_type", "generate")
|
| 59 |
+
|
| 60 |
+
parameters = inputs.get("parameters", {})
|
| 61 |
|
| 62 |
+
# Extract parameters with defaults
|
| 63 |
width = parameters.get("width", 224)
|
| 64 |
height = parameters.get("height", 224)
|
| 65 |
+
seed = parameters.get("seed", None)
|
| 66 |
+
input_svg = parameters.get("input_svg", None)
|
| 67 |
|
| 68 |
+
if seed is not None:
|
| 69 |
+
torch.manual_seed(seed)
|
| 70 |
+
np.random.seed(seed)
|
| 71 |
+
random.seed(seed)
|
| 72 |
|
| 73 |
print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")
|
| 74 |
|
|
|
|
| 89 |
pil_image = self.svg_to_pil_image(svg_content, width, height)
|
| 90 |
|
| 91 |
# Store metadata
|
| 92 |
+
pil_image.info['svg_content'] = svg_content
|
| 93 |
for key, value in metadata.items():
|
| 94 |
if isinstance(value, (dict, list)):
|
| 95 |
pil_image.info[key] = json.dumps(value)
|
|
|
|
| 112 |
try:
|
| 113 |
print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
|
| 114 |
|
| 115 |
+
# Analyze word differences
|
| 116 |
+
added_words, removed_words = self.analyze_word_differences(source_prompt, target_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
print(f"Added words: {added_words}, Removed words: {removed_words}")
|
| 118 |
|
| 119 |
+
# Generate or use base SVG
|
| 120 |
if input_svg:
|
| 121 |
base_svg = input_svg
|
| 122 |
else:
|
|
|
|
| 173 |
try:
|
| 174 |
print(f"Attention reweighting for: '{prompt}'")
|
| 175 |
|
| 176 |
+
# Parse attention weights from prompt (e.g., "(cat:1.5)" or "[table:0.5]")
|
| 177 |
weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
|
| 178 |
+
print(f"Weighted prompt: '{weighted_prompt}', weights: {attention_weights}")
|
| 179 |
|
| 180 |
# Generate or use base SVG
|
| 181 |
if input_svg:
|
|
|
|
| 226 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 227 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 228 |
|
| 229 |
+
# Extract semantic features
|
| 230 |
+
features = self.extract_semantic_features(prompt)
|
| 231 |
+
|
| 232 |
+
# Generate content based on prompt
|
| 233 |
+
if any(word in prompt.lower() for word in ['person', 'people', 'human', 'man', 'woman']):
|
| 234 |
+
self.add_person_elements(dwg, width, height, features)
|
| 235 |
+
elif any(word in prompt.lower() for word in ['animal', 'cat', 'dog', 'bird', 'horse']):
|
| 236 |
+
self.add_animal_elements(dwg, width, height, features)
|
| 237 |
+
elif any(word in prompt.lower() for word in ['house', 'building', 'architecture']):
|
| 238 |
+
self.add_building_elements(dwg, width, height, features)
|
| 239 |
+
elif any(word in prompt.lower() for word in ['tree', 'nature', 'landscape']):
|
| 240 |
+
self.add_nature_elements(dwg, width, height, features)
|
| 241 |
+
elif any(word in prompt.lower() for word in ['car', 'vehicle', 'transport']):
|
| 242 |
+
self.add_vehicle_elements(dwg, width, height, features)
|
|
|
|
| 243 |
else:
|
| 244 |
+
self.add_abstract_elements(dwg, width, height, features)
|
| 245 |
|
| 246 |
return dwg.tostring()
|
| 247 |
|
| 248 |
+
def analyze_word_differences(self, source: str, target: str):
|
| 249 |
+
"""Analyze differences between source and target prompts"""
|
| 250 |
+
source_words = set(source.lower().split())
|
| 251 |
+
target_words = set(target.lower().split())
|
| 252 |
+
|
| 253 |
+
added_words = target_words - source_words
|
| 254 |
+
removed_words = source_words - target_words
|
| 255 |
+
|
| 256 |
+
return added_words, removed_words
|
| 257 |
+
|
| 258 |
+
def parse_attention_weights(self, prompt: str):
|
| 259 |
+
"""Parse attention weights from prompt"""
|
| 260 |
+
# Pattern for (word:weight) - increase attention
|
| 261 |
+
increase_pattern = r'\(([^:]+):([0-9.]+)\)'
|
| 262 |
+
# Pattern for [word:weight] - decrease attention
|
| 263 |
+
decrease_pattern = r'\[([^:]+):([0-9.]+)\]'
|
| 264 |
+
|
| 265 |
+
attention_weights = {}
|
| 266 |
+
weighted_prompt = prompt
|
| 267 |
+
|
| 268 |
+
# Find increase weights
|
| 269 |
+
for match in re.finditer(increase_pattern, prompt):
|
| 270 |
+
word = match.group(1).strip()
|
| 271 |
+
weight = float(match.group(2))
|
| 272 |
+
attention_weights[word] = weight
|
| 273 |
+
# Remove the weight notation from prompt
|
| 274 |
+
weighted_prompt = weighted_prompt.replace(match.group(0), word)
|
| 275 |
+
|
| 276 |
+
# Find decrease weights
|
| 277 |
+
for match in re.finditer(decrease_pattern, prompt):
|
| 278 |
+
word = match.group(1).strip()
|
| 279 |
+
weight = float(match.group(2))
|
| 280 |
+
attention_weights[word] = weight
|
| 281 |
+
# Remove the weight notation from prompt
|
| 282 |
+
weighted_prompt = weighted_prompt.replace(match.group(0), word)
|
| 283 |
+
|
| 284 |
+
return weighted_prompt.strip(), attention_weights
|
| 285 |
+
|
| 286 |
+
def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
|
| 287 |
+
added_words: set, removed_words: set, width: int, height: int):
|
| 288 |
"""Apply word replacement transformations to SVG"""
|
| 289 |
+
# For now, regenerate with target prompt but keep some base structure
|
| 290 |
+
# In a full implementation, this would do more sophisticated editing
|
| 291 |
+
|
| 292 |
+
# Parse the base SVG to understand its structure
|
| 293 |
+
features = self.extract_semantic_features(target_prompt)
|
| 294 |
+
|
| 295 |
+
# Create new SVG with target prompt characteristics
|
| 296 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 297 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 298 |
|
| 299 |
+
# Apply changes based on word differences
|
| 300 |
+
if any(word in added_words for word in ['red', 'blue', 'green', 'yellow']):
|
| 301 |
+
# Color change
|
| 302 |
+
self.add_colored_elements(dwg, width, height, added_words)
|
| 303 |
+
elif any(word in added_words for word in ['big', 'large', 'huge']):
|
| 304 |
+
# Size change
|
| 305 |
+
self.add_large_elements(dwg, width, height, features)
|
| 306 |
+
elif any(word in added_words for word in ['small', 'tiny', 'mini']):
|
| 307 |
+
# Size change
|
| 308 |
+
self.add_small_elements(dwg, width, height, features)
|
| 309 |
+
else:
|
| 310 |
+
# General content change
|
| 311 |
+
self.add_content_based_on_prompt(dwg, target_prompt, width, height)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
return dwg.tostring()
|
| 314 |
|
| 315 |
def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
|
| 316 |
"""Apply refinement to existing SVG"""
|
| 317 |
+
# For now, enhance the base SVG with additional details
|
| 318 |
+
features = self.extract_semantic_features(prompt)
|
| 319 |
+
|
| 320 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 321 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 322 |
|
| 323 |
+
# Add refined elements based on prompt
|
| 324 |
+
if features.get('detailed', False):
|
| 325 |
+
self.add_detailed_elements(dwg, width, height, features)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
else:
|
| 327 |
+
self.add_content_based_on_prompt(dwg, prompt, width, height)
|
|
|
|
| 328 |
|
| 329 |
return dwg.tostring()
|
| 330 |
|
| 331 |
def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
|
| 332 |
+
"""Apply attention reweighting to SVG"""
|
| 333 |
dwg = svgwrite.Drawing(size=(width, height))
|
| 334 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
| 335 |
|
| 336 |
+
# Apply different emphasis based on attention weights
|
| 337 |
for word, weight in attention_weights.items():
|
| 338 |
if weight > 1.0:
|
| 339 |
# Emphasize this element
|
| 340 |
+
self.add_emphasized_element(dwg, word, weight, width, height)
|
| 341 |
elif weight < 1.0:
|
| 342 |
# De-emphasize this element
|
| 343 |
+
self.add_deemphasized_element(dwg, word, weight, width, height)
|
| 344 |
|
| 345 |
+
# Add base content
|
| 346 |
+
self.add_content_based_on_prompt(dwg, prompt, width, height)
|
| 347 |
|
| 348 |
return dwg.tostring()
|
| 349 |
|
| 350 |
+
def add_person_elements(self, dwg, width, height, features):
|
| 351 |
+
"""Add person-like elements"""
|
| 352 |
+
center_x, center_y = width // 2, height // 2
|
| 353 |
|
| 354 |
+
# Head
|
| 355 |
+
head_radius = 20
|
| 356 |
+
dwg.add(dwg.circle(center=(center_x, center_y - 40), r=head_radius, fill='#FDBCB4', stroke='black', stroke_width=2))
|
| 357 |
|
| 358 |
+
# Body
|
| 359 |
+
body_height = 60
|
| 360 |
+
body_width = 30
|
| 361 |
+
dwg.add(dwg.rect(
|
| 362 |
+
insert=(center_x - body_width//2, center_y - 10),
|
| 363 |
+
size=(body_width, body_height),
|
| 364 |
+
fill='#4A90E2',
|
| 365 |
+
stroke='black',
|
| 366 |
+
stroke_width=2
|
| 367 |
+
))
|
| 368 |
+
|
| 369 |
+
# Arms
|
| 370 |
+
dwg.add(dwg.line(start=(center_x - body_width//2, center_y), end=(center_x - 40, center_y + 20), stroke='black', stroke_width=3))
|
| 371 |
+
dwg.add(dwg.line(start=(center_x + body_width//2, center_y), end=(center_x + 40, center_y + 20), stroke='black', stroke_width=3))
|
| 372 |
+
|
| 373 |
+
# Legs
|
| 374 |
+
dwg.add(dwg.line(start=(center_x - 10, center_y + body_height - 10), end=(center_x - 20, center_y + body_height + 30), stroke='black', stroke_width=3))
|
| 375 |
+
dwg.add(dwg.line(start=(center_x + 10, center_y + body_height - 10), end=(center_x + 20, center_y + body_height + 30), stroke='black', stroke_width=3))
|
| 376 |
|
| 377 |
+
def add_animal_elements(self, dwg, width, height, features):
|
| 378 |
+
"""Add animal-like elements"""
|
| 379 |
+
center_x, center_y = width // 2, height // 2
|
| 380 |
+
|
| 381 |
+
# Body (oval)
|
| 382 |
+
dwg.add(dwg.ellipse(center=(center_x, center_y), r=(40, 25), fill='#8B4513', stroke='black', stroke_width=2))
|
| 383 |
+
|
| 384 |
+
# Head
|
| 385 |
+
dwg.add(dwg.circle(center=(center_x - 30, center_y - 10), r=20, fill='#A0522D', stroke='black', stroke_width=2))
|
| 386 |
+
|
| 387 |
+
# Legs
|
| 388 |
+
for i, x_offset in enumerate([-20, -10, 10, 20]):
|
| 389 |
+
dwg.add(dwg.line(
|
| 390 |
+
start=(center_x + x_offset, center_y + 25),
|
| 391 |
+
end=(center_x + x_offset, center_y + 45),
|
| 392 |
+
stroke='black',
|
| 393 |
+
stroke_width=3
|
| 394 |
+
))
|
| 395 |
+
|
| 396 |
+
# Tail
|
| 397 |
+
dwg.add(dwg.path(
|
| 398 |
+
d=f"M {center_x + 40},{center_y} Q {center_x + 60},{center_y - 20} {center_x + 50},{center_y - 35}",
|
| 399 |
+
stroke='black',
|
| 400 |
+
stroke_width=3,
|
| 401 |
+
fill='none'
|
| 402 |
+
))
|
| 403 |
+
|
| 404 |
+
def add_building_elements(self, dwg, width, height, features):
|
| 405 |
+
"""Add building-like elements"""
|
| 406 |
+
# Main building
|
| 407 |
+
building_width = width * 0.6
|
| 408 |
+
building_height = height * 0.7
|
| 409 |
+
x = (width - building_width) // 2
|
| 410 |
+
y = height - building_height - 10
|
| 411 |
+
|
| 412 |
dwg.add(dwg.rect(
|
| 413 |
+
insert=(x, y),
|
| 414 |
+
size=(building_width, building_height),
|
| 415 |
+
fill='#CD853F',
|
| 416 |
stroke='black',
|
| 417 |
stroke_width=2
|
| 418 |
))
|
| 419 |
|
| 420 |
# Roof
|
| 421 |
+
roof_points = [(x, y), (x + building_width//2, y - 30), (x + building_width, y)]
|
| 422 |
+
dwg.add(dwg.polygon(points=roof_points, fill='#8B0000', stroke='black', stroke_width=2))
|
| 423 |
+
|
| 424 |
+
# Windows
|
| 425 |
+
window_size = 15
|
| 426 |
+
for i in range(3):
|
| 427 |
+
for j in range(4):
|
| 428 |
+
wx = x + 15 + i * 30
|
| 429 |
+
wy = y + 15 + j * 25
|
| 430 |
+
if wy < y + building_height - 20:
|
| 431 |
+
dwg.add(dwg.rect(
|
| 432 |
+
insert=(wx, wy),
|
| 433 |
+
size=(window_size, window_size),
|
| 434 |
+
fill='#87CEEB',
|
| 435 |
+
stroke='black',
|
| 436 |
+
stroke_width=1
|
| 437 |
+
))
|
| 438 |
|
| 439 |
# Door
|
| 440 |
+
door_width = 20
|
| 441 |
+
door_height = 40
|
| 442 |
+
door_x = x + building_width//2 - door_width//2
|
| 443 |
+
door_y = y + building_height - door_height
|
|
|
|
| 444 |
dwg.add(dwg.rect(
|
| 445 |
insert=(door_x, door_y),
|
| 446 |
size=(door_width, door_height),
|
| 447 |
+
fill='#8B4513',
|
| 448 |
stroke='black',
|
| 449 |
stroke_width=2
|
| 450 |
))
|
| 451 |
|
| 452 |
+
def add_nature_elements(self, dwg, width, height, features):
|
| 453 |
+
"""Add nature-like elements"""
|
| 454 |
+
# Tree
|
| 455 |
+
center_x, center_y = width // 2, height // 2
|
| 456 |
|
| 457 |
# Trunk
|
| 458 |
+
trunk_width = 15
|
| 459 |
+
trunk_height = height // 3
|
| 460 |
+
trunk_x = center_x - trunk_width // 2
|
| 461 |
+
trunk_y = height - trunk_height - 10
|
| 462 |
+
|
| 463 |
dwg.add(dwg.rect(
|
| 464 |
+
insert=(trunk_x, trunk_y),
|
| 465 |
size=(trunk_width, trunk_height),
|
| 466 |
+
fill='#8B4513',
|
| 467 |
stroke='black',
|
| 468 |
+
stroke_width=1
|
| 469 |
))
|
| 470 |
|
| 471 |
+
# Crown (multiple circles for foliage)
|
| 472 |
+
crown_radius = 30
|
| 473 |
+
for i, (dx, dy) in enumerate([(-15, -20), (15, -20), (0, -35), (-10, -50), (10, -50)]):
|
| 474 |
+
dwg.add(dwg.circle(
|
| 475 |
+
center=(center_x + dx, center_y + dy),
|
| 476 |
+
r=crown_radius - i * 3,
|
| 477 |
+
fill='#228B22',
|
| 478 |
+
stroke='#006400',
|
| 479 |
+
stroke_width=1,
|
| 480 |
+
opacity=0.8
|
| 481 |
+
))
|
| 482 |
|
| 483 |
+
def add_vehicle_elements(self, dwg, width, height, features):
|
| 484 |
+
"""Add vehicle-like elements"""
|
| 485 |
+
center_x, center_y = width // 2, height // 2
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
# Car body
|
| 488 |
+
car_width = width * 0.6
|
| 489 |
+
car_height = height * 0.3
|
| 490 |
+
car_x = (width - car_width) // 2
|
| 491 |
+
car_y = center_y + 10
|
| 492 |
+
|
| 493 |
dwg.add(dwg.rect(
|
| 494 |
insert=(car_x, car_y),
|
| 495 |
size=(car_width, car_height),
|
| 496 |
+
fill='#FF4500',
|
| 497 |
stroke='black',
|
| 498 |
stroke_width=2,
|
| 499 |
rx=5
|
| 500 |
))
|
| 501 |
|
| 502 |
+
# Windshield
|
| 503 |
+
windshield_width = car_width * 0.6
|
| 504 |
+
windshield_height = car_height * 0.4
|
| 505 |
+
windshield_x = car_x + (car_width - windshield_width) // 2
|
| 506 |
+
windshield_y = car_y - windshield_height + 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
+
dwg.add(dwg.rect(
|
| 509 |
+
insert=(windshield_x, windshield_y),
|
| 510 |
+
size=(windshield_width, windshield_height),
|
| 511 |
+
fill='#87CEEB',
|
|
|
|
| 512 |
stroke='black',
|
| 513 |
+
stroke_width=1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
))
|
| 515 |
|
| 516 |
+
# Wheels
|
| 517 |
+
wheel_radius = 12
|
| 518 |
+
wheel_y = car_y + car_height - 5
|
| 519 |
+
dwg.add(dwg.circle(center=(car_x + 25, wheel_y), r=wheel_radius, fill='black'))
|
| 520 |
+
dwg.add(dwg.circle(center=(car_x + car_width - 25, wheel_y), r=wheel_radius, fill='black'))
|
|
|
|
|
|
|
|
|
|
| 521 |
|
| 522 |
+
def add_abstract_elements(self, dwg, width, height, features):
|
| 523 |
+
"""Add abstract elements"""
|
| 524 |
+
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
|
| 526 |
+
for i in range(5):
|
| 527 |
+
shape_type = random.choice(['circle', 'rect', 'path'])
|
| 528 |
+
color = random.choice(colors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
|
| 530 |
+
if shape_type == 'circle':
|
| 531 |
+
radius = random.randint(10, 30)
|
| 532 |
+
x = random.randint(radius, width - radius)
|
| 533 |
+
y = random.randint(radius, height - radius)
|
| 534 |
+
dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
|
| 535 |
+
elif shape_type == 'rect':
|
| 536 |
+
w = random.randint(20, 60)
|
| 537 |
+
h = random.randint(20, 60)
|
| 538 |
+
x = random.randint(0, width - w)
|
| 539 |
+
y = random.randint(0, height - h)
|
| 540 |
+
dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
|
| 541 |
+
else:
|
| 542 |
+
# Random path
|
| 543 |
+
start_x = random.randint(0, width)
|
| 544 |
+
start_y = random.randint(0, height)
|
| 545 |
+
end_x = random.randint(0, width)
|
| 546 |
+
end_y = random.randint(0, height)
|
| 547 |
+
dwg.add(dwg.line(start=(start_x, start_y), end=(end_x, end_y), stroke=color, stroke_width=3))
|
| 548 |
|
| 549 |
+
def add_colored_elements(self, dwg, width, height, color_words):
|
| 550 |
+
"""Add elements with specific colors"""
|
| 551 |
color_map = {
|
| 552 |
'red': '#FF0000',
|
| 553 |
'blue': '#0000FF',
|
| 554 |
'green': '#00FF00',
|
| 555 |
'yellow': '#FFFF00',
|
| 556 |
+
'purple': '#800080',
|
| 557 |
+
'orange': '#FFA500'
|
| 558 |
}
|
| 559 |
|
| 560 |
+
center_x, center_y = width // 2, height // 2
|
| 561 |
|
| 562 |
+
for word in color_words:
|
| 563 |
+
if word in color_map:
|
| 564 |
+
color = color_map[word]
|
| 565 |
+
# Add a colored shape
|
| 566 |
+
dwg.add(dwg.circle(
|
| 567 |
+
center=(center_x + random.randint(-50, 50), center_y + random.randint(-50, 50)),
|
| 568 |
+
r=random.randint(15, 35),
|
| 569 |
+
fill=color,
|
| 570 |
+
opacity=0.8
|
| 571 |
+
))
|
| 572 |
+
|
| 573 |
+
def add_large_elements(self, dwg, width, height, features):
|
| 574 |
+
"""Add large-sized elements"""
|
| 575 |
+
center_x, center_y = width // 2, height // 2
|
| 576 |
+
|
| 577 |
+
# Large central element
|
| 578 |
dwg.add(dwg.circle(
|
| 579 |
+
center=(center_x, center_y),
|
| 580 |
+
r=min(width, height) // 3,
|
| 581 |
+
fill='#4A90E2',
|
| 582 |
stroke='black',
|
| 583 |
+
stroke_width=3
|
| 584 |
))
|
| 585 |
|
| 586 |
+
def add_small_elements(self, dwg, width, height, features):
|
| 587 |
+
"""Add small-sized elements"""
|
| 588 |
+
# Multiple small elements
|
| 589 |
+
for i in range(8):
|
| 590 |
+
x = random.randint(10, width - 10)
|
| 591 |
+
y = random.randint(10, height - 10)
|
| 592 |
+
dwg.add(dwg.circle(
|
| 593 |
+
center=(x, y),
|
| 594 |
+
r=random.randint(3, 8),
|
| 595 |
+
fill='#E74C3C',
|
| 596 |
+
opacity=0.7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
))
|
| 598 |
|
| 599 |
+
def add_detailed_elements(self, dwg, width, height, features):
|
| 600 |
+
"""Add detailed elements for refinement"""
|
| 601 |
+
# Add more complex shapes and details
|
| 602 |
+
self.add_abstract_elements(dwg, width, height, features)
|
| 603 |
+
|
| 604 |
+
# Add decorative elements
|
| 605 |
+
center_x, center_y = width // 2, height // 2
|
| 606 |
+
for i in range(4):
|
| 607 |
+
angle = i * math.pi / 2
|
| 608 |
+
x = center_x + 40 * math.cos(angle)
|
| 609 |
+
y = center_y + 40 * math.sin(angle)
|
| 610 |
+
dwg.add(dwg.circle(center=(x, y), r=8, fill='#9B59B6', opacity=0.6))
|
| 611 |
|
| 612 |
+
def add_emphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
|
| 613 |
+
"""Add emphasized element based on attention weight"""
|
| 614 |
+
center_x, center_y = width // 2, height // 2
|
| 615 |
+
|
| 616 |
+
# Scale size based on weight
|
| 617 |
+
base_size = 20
|
| 618 |
+
size = int(base_size * weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
|
| 620 |
dwg.add(dwg.circle(
|
| 621 |
+
center=(center_x + random.randint(-30, 30), center_y + random.randint(-30, 30)),
|
| 622 |
+
r=size,
|
| 623 |
+
fill='#FF6B6B',
|
| 624 |
+
opacity=min(1.0, weight / 2),
|
| 625 |
stroke='black',
|
| 626 |
stroke_width=2
|
| 627 |
))
|
| 628 |
|
| 629 |
+
def add_deemphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
|
| 630 |
+
"""Add de-emphasized element based on attention weight"""
|
| 631 |
+
center_x, center_y = width // 2, height // 2
|
| 632 |
|
| 633 |
+
# Scale size based on weight
|
| 634 |
+
base_size = 15
|
| 635 |
+
size = int(base_size * weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
+
dwg.add(dwg.circle(
|
| 638 |
+
center=(center_x + random.randint(-40, 40), center_y + random.randint(-40, 40)),
|
| 639 |
+
r=max(3, size),
|
| 640 |
+
fill='#CCCCCC',
|
| 641 |
+
opacity=weight,
|
| 642 |
+
stroke='gray',
|
| 643 |
+
stroke_width=1
|
| 644 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
|
| 646 |
+
def add_content_based_on_prompt(self, dwg, prompt: str, width: int, height: int):
|
| 647 |
+
"""Add content based on prompt analysis"""
|
| 648 |
+
features = self.extract_semantic_features(prompt)
|
| 649 |
+
|
| 650 |
+
if any(word in prompt.lower() for word in ['person', 'people', 'human']):
|
| 651 |
+
self.add_person_elements(dwg, width, height, features)
|
| 652 |
+
elif any(word in prompt.lower() for word in ['animal', 'cat', 'dog']):
|
| 653 |
+
self.add_animal_elements(dwg, width, height, features)
|
| 654 |
+
elif any(word in prompt.lower() for word in ['house', 'building']):
|
| 655 |
+
self.add_building_elements(dwg, width, height, features)
|
| 656 |
+
elif any(word in prompt.lower() for word in ['tree', 'nature']):
|
| 657 |
+
self.add_nature_elements(dwg, width, height, features)
|
| 658 |
+
elif any(word in prompt.lower() for word in ['car', 'vehicle']):
|
| 659 |
+
self.add_vehicle_elements(dwg, width, height, features)
|
| 660 |
+
else:
|
| 661 |
+
self.add_abstract_elements(dwg, width, height, features)
|
| 662 |
|
| 663 |
+
def extract_semantic_features(self, prompt: str):
|
| 664 |
+
"""Extract semantic features from prompt"""
|
| 665 |
+
features = {
|
| 666 |
+
'detailed': False,
|
| 667 |
+
'simple': False,
|
| 668 |
+
'colorful': False,
|
| 669 |
+
'large': False,
|
| 670 |
+
'small': False
|
|
|
|
| 671 |
}
|
| 672 |
+
|
| 673 |
+
prompt_lower = prompt.lower()
|
| 674 |
+
|
| 675 |
+
if any(word in prompt_lower for word in ['detailed', 'complex', 'intricate']):
|
| 676 |
+
features['detailed'] = True
|
| 677 |
+
if any(word in prompt_lower for word in ['simple', 'minimal', 'basic']):
|
| 678 |
+
features['simple'] = True
|
| 679 |
+
if any(word in prompt_lower for word in ['colorful', 'bright', 'vibrant']):
|
| 680 |
+
features['colorful'] = True
|
| 681 |
+
if any(word in prompt_lower for word in ['large', 'big', 'huge']):
|
| 682 |
+
features['large'] = True
|
| 683 |
+
if any(word in prompt_lower for word in ['small', 'tiny', 'mini']):
|
| 684 |
+
features['small'] = True
|
| 685 |
+
|
| 686 |
+
return features
|
| 687 |
|
| 688 |
+
def svg_to_pil_image(self, svg_content: str, width: int, height: int):
|
| 689 |
"""Convert SVG content to PIL Image"""
|
| 690 |
try:
|
| 691 |
import cairosvg
|
|
|
|
| 692 |
|
| 693 |
# Convert SVG to PNG bytes
|
| 694 |
png_bytes = cairosvg.svg2png(
|
|
|
|
| 719 |
|
| 720 |
# Simple centered text
|
| 721 |
dwg.add(dwg.text(
|
| 722 |
+
f"DiffSketchEdit\n{prompt[:30]}...",
|
| 723 |
insert=(width/2, height/2),
|
| 724 |
text_anchor="middle",
|
| 725 |
+
font_size="12px",
|
| 726 |
fill="black"
|
| 727 |
))
|
| 728 |
|