"""Differentiable renderer for brushstrokes parameterized as quadratic Bezier curves."""

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF

from utils import initialize_brushstrokes

# Spacing (in pixels) of the coarse grid on which the nearest strokes are
# looked up before the index map is upsampled to full resolution.
_KNN_GRID_STRIDE = 5


def sample_quadratic_bezier_curve(s, c, e, num_points=10):
    """
    Sample points along quadratic Bezier curves.

    Inputs:
        s: start points [N, 2]
        c: control points [N, 2]
        e: end points [N, 2]
        num_points: number of samples per curve (t is spaced evenly on [0, 1])
    Output:
        points on curves [N, num_points, 2]
    """
    t = torch.linspace(0.0, 1.0, num_points, device=s.device).view(1, num_points, 1)

    # Only the first two coordinates (x, y) of each point are used.
    start = s[:, 0:2].unsqueeze(1)  # [N, 1, 2]
    ctrl = c[:, 0:2].unsqueeze(1)  # [N, 1, 2]
    end = e[:, 0:2].unsqueeze(1)  # [N, 1, 2]

    # B(t) = (1-t)^2 * p0 + 2(1-t)t * p1 + t^2 * p2, written in the equivalent
    # form p1 + (1-t)^2 * (p0 - p1) + t^2 * (p2 - p1).
    return ctrl + (1.0 - t) ** 2 * (start - ctrl) + t ** 2 * (end - ctrl)


def _knn_on_grid(locations, grid_points, k):
    """
    Find the k nearest brushstroke indices for each grid point using cdist.

    Args:
        locations: [N, 2] brushstroke center locations
        grid_points: [M, 2] grid query points
        k: number of neighbors (must not exceed N; the caller clamps it)
    Returns:
        indices: [M, k] indices into locations
    """
    dists = torch.cdist(grid_points, locations)  # [M, N]
    # largest=False selects the k smallest distances, i.e. the nearest strokes.
    _, indices = torch.topk(dists, k, dim=1, largest=False)
    return indices  # [M, k]


@torch.jit.script
def _render_strokes(
    curve_points: torch.Tensor,
    locations: torch.Tensor,
    colors: torch.Tensor,
    widths: torch.Tensor,
    indices: torch.Tensor,
    H: int,
    W: int,
    K: int,
    canvas_color: float,
) -> torch.Tensor:
    """
    Core rendering: given KNN indices, render brushstrokes onto the canvas.

    Inputs:
        curve_points: [N, S, 2] sampled points on each Bezier curve
        locations: [N, 2] brushstroke locations (not used here; kept so the
            signature stays stable for callers)
        colors: [N, 3] brushstroke colors (caller clamps to 0-1)
        widths: [N, 1] brushstroke widths (caller has already applied exp)
        indices: [H, W, K] nearest brush indices per pixel
        H, W: canvas dimensions
        K: strokes per pixel
        canvas_color: background intensity
    Output:
        canvas: [H, W, 3]

    Note: all distances below are *squared* Euclidean distances; no square
    root is taken before they are compared with the stroke width.
    """
    device = curve_points.device
    S = curve_points.shape[1]

    # Gather curve points, colors and widths of the K candidate strokes per pixel.
    flat_idx = indices.flatten()  # [H*W*K]
    canvas_curves = curve_points[flat_idx].view(H, W, K, S, 2)
    canvas_colors = colors[flat_idx].view(H, W, K, 3)
    canvas_widths = widths[flat_idx].view(H, W, K, 1)

    # Full-resolution pixel grid in (x, y) order.
    t_H = torch.linspace(0.0, float(H - 1), H, device=device)
    t_W = torch.linspace(0.0, float(W - 1), W, device=device)
    P_y, P_x = torch.meshgrid(t_H, t_W, indexing="ij")
    P_full = torch.stack([P_x, P_y], dim=-1)  # [H, W, 2]
    pixels = P_full[:, :, None, None, :]  # [H, W, 1, 1, 2]

    # Treat each curve as a polyline of S-1 segments.
    seg_a = canvas_curves[:, :, :, :-1, :]  # [H, W, K, S-1, 2]
    seg_b = canvas_curves[:, :, :, 1:, :]  # [H, W, K, S-1, 2]
    seg_ba = seg_b - seg_a  # [H, W, K, S-1, 2]

    # Vector from segment start to pixel.
    p_a = pixels - seg_a  # [H, W, K, S-1, 2]

    # Project the pixel onto each segment; the epsilon guards zero-length
    # segments, and clamping keeps the projection inside the segment.
    t = torch.sum(seg_ba * p_a, dim=-1) / (torch.sum(seg_ba ** 2, dim=-1) + 1e-8)
    t = torch.clamp(t, 0.0, 1.0)

    # Closest point on each segment and its squared distance to the pixel.
    closest = seg_a + t.unsqueeze(-1) * seg_ba  # [H, W, K, S-1, 2]
    dist_sq = torch.sum((pixels - closest) ** 2, dim=-1)  # [H, W, K, S-1]

    # Minimum over segments (per stroke), then over the K candidate strokes.
    D_per_stroke = torch.amin(dist_sq, dim=-1)  # [H, W, K]
    D = torch.amin(D_per_stroke, dim=-1)  # [H, W]

    # Softmax over inverse distance with a large temperature: close to a hard
    # "nearest stroke wins" selection while remaining differentiable.
    ranking = F.softmax(100000.0 * (1.0 / (1e-8 + D_per_stroke)), dim=-1)  # [H, W, K]

    # Ranking-weighted color and width at each pixel.
    I_colors = torch.einsum("hwkc,hwk->hwc", canvas_colors, ranking)  # [H, W, 3]
    bs = torch.einsum("hwkc,hwk->hwc", canvas_widths, ranking)  # [H, W, 1]

    # Stroke alpha mask: sigmoid of (width - squared distance).
    bs_mask = torch.sigmoid(bs - D.unsqueeze(-1))  # [H, W, 1]

    background = torch.ones_like(I_colors) * canvas_color
    rendered = I_colors * bs_mask + (1.0 - bs_mask) * background
    return rendered  # [H, W, 3]


def stroke_renderer(curve_points, locations, colors, widths, H, W, K, canvas_color):
    """
    Full differentiable brushstroke renderer.

    Uses a coarse grid + KNN to find the relevant strokes per pixel, then
    renders with distance-based alpha compositing.

    Inputs:
        curve_points: [N, S, 2] points on Bezier curves
        locations: [N, 2] stroke center locations; clamped to the canvas and
            used only for the nearest-stroke lookup (curve_points are used as
            given)
        colors: [N, 3] stroke colors (clamped to [0, 1] here)
        widths: [N, 1] stroke widths (log-space, exponentiated here)
        H, W: canvas height and width
        K: number of nearest strokes to consider per pixel (capped at N)
        canvas_color: background color value
    Output:
        canvas: [H, W, 3]. When there is nothing to draw (N == 0 or K <= 0)
        the plain background canvas is returned.
    """
    colors = torch.clamp(colors, 0.0, 1.0)
    loc_x = torch.clamp(locations[:, 0:1], 0, W - 1)
    loc_y = torch.clamp(locations[:, 1:2], 0, H - 1)
    locations = torch.cat([loc_x, loc_y], dim=1)
    widths = torch.exp(widths)
    device = curve_points.device

    # Never ask for more neighbors than there are strokes.
    K_actual = min(K, locations.shape[0])
    if K_actual <= 0:
        # No candidate strokes: the reductions over the K axis below would
        # fail on an empty dimension, so return the background directly.
        return torch.full(
            (H, W, 3), float(canvas_color), dtype=colors.dtype, device=device
        )

    # Coarse grid for KNN (roughly one query point every _KNN_GRID_STRIDE pixels).
    cH = max(int(H // _KNN_GRID_STRIDE), 1)
    cW = max(int(W // _KNN_GRID_STRIDE), 1)
    t_H = torch.linspace(0.0, float(H - 1), cH, device=device)
    t_W = torch.linspace(0.0, float(W - 1), cW, device=device)
    P_y, P_x = torch.meshgrid(t_H, t_W, indexing="ij")
    coarse_grid = torch.stack([P_x, P_y], dim=-1).view(-1, 2)  # [G, 2]

    # K nearest strokes for each coarse grid cell.
    indices = _knn_on_grid(locations, coarse_grid, K_actual)  # [G, K]

    # Reshape to the coarse grid, channels-first for the resize call.
    indices = indices.view(cH, cW, K_actual).permute(2, 0, 1)  # [K, cH, cW]

    # Upsample the index map with nearest interpolation so values stay valid
    # stroke indices. The round-trip through float is exact as long as the
    # indices are representable in float32.
    indices = TF.resize(
        indices.float(), size=[H, W], interpolation=TF.InterpolationMode.NEAREST
    ).long()  # [K, H, W]
    indices = indices.permute(1, 2, 0)  # [H, W, K]

    return _render_strokes(
        curve_points, locations, colors, widths, indices, H, W, K_actual, canvas_color
    )


def _to_parameter(array):
    """Wrap a numpy array as a float32 torch Parameter."""
    return torch.nn.Parameter(torch.from_numpy(array.astype(np.float32)))


class BrushStrokeRenderer(torch.nn.Module):
    """
    Differentiable brushstroke renderer module.

    Parameterizes N brushstrokes, each with:
        - location: [N, 2] center position on canvas
        - curve_s, curve_e, curve_c: [N, 2] start/end/control points (relative to location)
        - color: [N, 3] RGB color
        - width: [N, 1] stroke width (log-space)
    """

    def __init__(
        self,
        canvas_height,
        canvas_width,
        num_strokes=5000,
        samples_per_curve=10,
        strokes_per_pixel=20,
        canvas_color="gray",
        length_scale=1.1,
        width_scale=0.1,
        content_img=None,
    ):
        """
        Build the stroke parameters.

        canvas_color selects the background intensity: "gray" -> 0.5,
        "black" -> 0.0, "noise" -> 0.5 (fallback), anything else -> 1.0.
        Strokes are initialized from content_img when it is given
        (init="sp"), otherwise randomly (init="random").
        """
        super().__init__()

        if canvas_color == "gray":
            self.canvas_color = 0.5
        elif canvas_color == "black":
            self.canvas_color = 0.0
        elif canvas_color == "noise":
            self.canvas_color = 0.5  # fallback
        else:
            self.canvas_color = 1.0

        self.canvas_height = canvas_height
        self.canvas_width = canvas_width
        self.num_strokes = num_strokes
        self.samples_per_curve = samples_per_curve
        self.strokes_per_pixel = strokes_per_pixel

        # Initialize brushstrokes from the content image or randomly.
        init_mode = "sp" if content_img is not None else "random"
        location, s, e, c, width, color = initialize_brushstrokes(
            content_img,
            num_strokes,
            canvas_height,
            canvas_width,
            length_scale,
            width_scale,
            init=init_mode,
        )

        # Reverse the last axis so that column 0 is x and column 1 is y, which
        # is the order the renderer expects. The copy makes the reversed view
        # contiguous, since torch.from_numpy cannot wrap negative strides.
        location = location[..., ::-1].copy()
        s = s[..., ::-1].copy()
        e = e[..., ::-1].copy()
        c = c[..., ::-1].copy()

        # Registration order is kept stable so parameters() / state_dict()
        # ordering does not change.
        self.curve_s = _to_parameter(s)
        self.curve_e = _to_parameter(e)
        self.curve_c = _to_parameter(c)
        self.color = _to_parameter(color)
        self.location = _to_parameter(location)
        # Widths are stored in log-space; the floor avoids log(0).
        self.width = _to_parameter(np.log(np.maximum(width, 1e-3)))

    def forward(self):
        """Render all strokes and return the canvas as an [H, W, 3] tensor."""
        # Curve points are stored relative to each stroke's location.
        curve_points = sample_quadratic_bezier_curve(
            s=self.curve_s + self.location,
            e=self.curve_e + self.location,
            c=self.curve_c + self.location,
            num_points=self.samples_per_curve,
        )
        canvas = stroke_renderer(
            curve_points,
            self.location,
            self.color,
            self.width,
            self.canvas_height,
            self.canvas_width,
            self.strokes_per_pixel,
            self.canvas_color,
        )
        return canvas