| import torch |
| import torch.nn.functional as F |
| from torchvision.transforms import functional as TF |
| import numpy as np |
|
|
| from utils import initialize_brushstrokes |
|
|
|
|
def sample_quadratic_bezier_curve(s, c, e, num_points=10):
    """
    Evaluate quadratic Bezier curves at evenly spaced parameters t in [0, 1].

    Uses the control-point-centred form
        B(t) = c + (1 - t)^2 * (s - c) + t^2 * (e - c),
    which expands to (1 - t)^2 * s + 2t(1 - t) * c + t^2 * e.

    Inputs:
        s: start points [N, 2] (only the first two columns are read)
        c: control points [N, 2]
        e: end points [N, 2]
        num_points: number of samples per curve, both endpoints included
    Output:
        points on curves [N, num_points, 2]
    """
    n_curves = s.shape[0]

    # One parameter value per sample, shaped [N, num_points, 1] so it
    # broadcasts over the (x, y) coordinate axis.
    ts = torch.linspace(0.0, 1.0, num_points, device=s.device)
    ts = ts.unsqueeze(0).expand(n_curves, -1).unsqueeze(-1)

    # [N, 1, 2] views of the first two coordinates of each point set.
    start, ctrl, end = (pts[:, 0:2].unsqueeze(1) for pts in (s, c, e))

    return ctrl + (1.0 - ts) ** 2 * (start - ctrl) + ts ** 2 * (end - ctrl)
|
|
|
|
| def _knn_on_grid(locations, grid_points, k): |
| """ |
| Find k nearest brushstroke indices for each grid point using cdist. |
| |
| Args: |
| locations: [N, 2] brushstroke center locations |
| grid_points: [M, 2] grid query points |
| k: number of neighbors |
| |
| Returns: |
| indices: [M, k] indices into locations |
| """ |
| dists = torch.cdist(grid_points, locations) |
| |
| _, indices = torch.topk(dists, k, dim=1, largest=False) |
| return indices |
|
|
|
|
| @torch.jit.script |
| def _render_strokes( |
| curve_points: torch.Tensor, |
| locations: torch.Tensor, |
| colors: torch.Tensor, |
| widths: torch.Tensor, |
| indices: torch.Tensor, |
| H: int, |
| W: int, |
| K: int, |
| canvas_color: float, |
| ) -> torch.Tensor: |
| """ |
| core rendering: given KNN indices, render brushstrokes onto canvas. |
| implements Algorithm 1 from the paper. |
| |
| Inputs: |
| curve_points: [N, S, 2] sampled points on each Bezier curve |
| locations: [N, 2] brushstroke locations (clamped) |
| colors: [N, 3] brushstroke colors (clamped 0-1) |
| widths: [N, 1] brushstroke widths (exp-space) |
| indices: [H, W, K] nearest brush indices per pixel |
| H, W: canvas dimensions |
| K: strokes per pixel |
| canvas_color: background intensity |
| |
| Output: |
| canvas: [H, W, 3] |
| """ |
| device = curve_points.device |
| N, S, _ = curve_points.shape |
|
|
| |
| flat_idx = indices.flatten() |
| canvas_curves = curve_points[flat_idx].view(H, W, K, S, 2) |
| canvas_colors = colors[flat_idx].view(H, W, K, 3) |
| canvas_widths = widths[flat_idx].view(H, W, K, 1) |
|
|
| |
| t_H = torch.linspace(0.0, float(H - 1), H, device=device) |
| t_W = torch.linspace(0.0, float(W - 1), W, device=device) |
| P_y, P_x = torch.meshgrid(t_H, t_W, indexing="ij") |
| P_full = torch.stack([P_x, P_y], dim=-1) |
|
|
| |
| seg_a = canvas_curves[:, :, :, :-1, :] |
| seg_b = canvas_curves[:, :, :, 1:, :] |
| seg_ba = seg_b - seg_a |
|
|
| |
| p_a = P_full[:, :, None, None, :] - seg_a |
|
|
| |
| t = torch.sum(seg_ba * p_a, dim=-1) / (torch.sum(seg_ba ** 2, dim=-1) + 1e-8) |
| t = torch.clamp(t, 0.0, 1.0) |
|
|
| |
| closest = seg_a + t.unsqueeze(-1) * seg_ba |
|
|
| |
| dist_sq = torch.sum((P_full[:, :, None, None, :] - closest) ** 2, dim=-1) |
|
|
| |
| D_per_stroke = torch.amin(dist_sq, dim=-1) |
|
|
| |
| D = torch.amin(D_per_stroke, dim=-1) |
|
|
| |
| ranking = F.softmax(100000.0 * (1.0 / (1e-8 + D_per_stroke)), dim=-1) |
|
|
| |
| I_colors = torch.einsum("hwkc,hwk->hwc", canvas_colors, ranking) |
| bs = torch.einsum("hwkc,hwk->hwc", canvas_widths, ranking) |
|
|
| |
| bs_mask = torch.sigmoid(bs - D.unsqueeze(-1)) |
|
|
| canvas = torch.ones_like(I_colors) * canvas_color |
| I = I_colors * bs_mask + (1.0 - bs_mask) * canvas |
| return I |
|
|
|
|
def stroke_renderer(
    curve_points, locations, colors, widths, H, W, K, canvas_color, grid_stride=5
):
    """
    Full differentiable brushstroke renderer (Algorithm 1).

    A coarse grid of query points (roughly one every ``grid_stride`` pixels per
    axis, at least one per axis) is matched to its K nearest stroke locations.
    Those candidate lists are upsampled to full resolution with
    nearest-neighbour interpolation and handed to ``_render_strokes`` for
    distance-based alpha compositing.

    Inputs:
        curve_points: [N, S, 2] points on Bezier curves, (x, y) order
        locations: [N, 2] stroke center locations (x, y). They are clamped to
            the canvas here and only influence candidate selection.
        colors: [N, 3] stroke colors; clamped to [0, 1]
        widths: [N, 1] stroke widths (log-space, will be exp'd)
        H, W: canvas height and width
        K: number of nearest strokes to consider per pixel (capped at N)
        canvas_color: background color value
        grid_stride: approximate spacing in pixels of the coarse KNN grid.
            Smaller values select candidates more accurately at higher cost.
            The default 5 reproduces the previous fixed behaviour.

    Output:
        canvas: [H, W, 3]

    Raises:
        ValueError: if ``grid_stride < 1``, or if no candidate strokes are
            available (``K < 1`` or no strokes at all).
    """
    if grid_stride < 1:
        raise ValueError(f"grid_stride must be >= 1, got {grid_stride!r}")

    colors = torch.clamp(colors, 0.0, 1.0)
    # Only the KNN candidate search below sees the clamped centres;
    # curve_points are not clamped.
    loc_x = torch.clamp(locations[:, 0:1], 0, W - 1)
    loc_y = torch.clamp(locations[:, 1:2], 0, H - 1)
    locations = torch.cat([loc_x, loc_y], dim=1)
    widths = torch.exp(widths)

    n_strokes = locations.shape[0]
    K_actual = min(K, n_strokes)
    if K_actual < 1:
        # Without this check, the failure surfaces later as an obscure
        # topk / amin / resize error on empty dimensions.
        raise ValueError(
            f"need at least one candidate stroke per pixel (K={K}, N={n_strokes})"
        )

    device = curve_points.device

    # Coarse query grid in (x, y) pixel coordinates spanning [0, W-1] x [0, H-1].
    t_H = torch.linspace(
        0.0, float(H - 1), max(int(H // grid_stride), 1), device=device
    )
    t_W = torch.linspace(
        0.0, float(W - 1), max(int(W // grid_stride), 1), device=device
    )
    P_y, P_x = torch.meshgrid(t_H, t_W, indexing="ij")
    coarse_grid = torch.stack([P_x, P_y], dim=-1).view(-1, 2)

    indices = _knn_on_grid(locations, coarse_grid, K_actual)

    # [cH * cW, K] -> [K, cH, cW], so each neighbour rank becomes one channel
    # of an image that can be resized.
    cH, cW = t_H.numel(), t_W.numel()
    indices = indices.view(cH, cW, K_actual).permute(2, 0, 1)

    # Nearest-neighbour upsampling to [K, H, W]. The indices round-trip
    # through float32, which represents integers exactly only up to 2**24.
    indices = TF.resize(
        indices.float(), size=[H, W], interpolation=TF.InterpolationMode.NEAREST
    ).long()
    indices = indices.permute(1, 2, 0)  # [H, W, K]

    return _render_strokes(
        curve_points, locations, colors, widths, indices, H, W, K_actual, canvas_color
    )
|
|
|
|
class BrushStrokeRenderer(torch.nn.Module):
    """
    Learnable set of N quadratic-Bezier brushstrokes rendered to an RGB canvas.

    Learnable parameters (stored as float32):
      - location: [N, 2] anchor position on the canvas, in (x, y) order
      - curve_s, curve_e, curve_c: [N, 2] start/end/control offsets, each
        added to ``location`` before the curve is sampled
      - color: [N, 3] RGB color (clamped to [0, 1] at render time)
      - width: [N, 1] stroke width in log space (exponentiated at render time)
    """

    def __init__(
        self,
        canvas_height,
        canvas_width,
        num_strokes=5000,
        samples_per_curve=10,
        strokes_per_pixel=20,
        canvas_color="gray",
        length_scale=1.1,
        width_scale=0.1,
        content_img=None,
    ):
        super().__init__()

        # Scalar background intensity for the renderer. "noise" has no scalar
        # equivalent and falls back to the "gray" level. Any value that matches
        # none of the names renders on white (1.0).
        for name, level in (("gray", 0.5), ("black", 0.0), ("noise", 0.5)):
            if canvas_color == name:
                self.canvas_color = level
                break
        else:
            self.canvas_color = 1.0

        self.canvas_height = canvas_height
        self.canvas_width = canvas_width
        self.num_strokes = num_strokes
        self.samples_per_curve = samples_per_curve
        self.strokes_per_pixel = strokes_per_pixel

        # Image-driven init ("sp") when a content image is supplied, otherwise
        # random placement. NOTE(review): "sp" presumably means superpixels;
        # confirm in utils.initialize_brushstrokes.
        init_mode = "random" if content_img is None else "sp"
        location, s, e, c, width, color = initialize_brushstrokes(
            content_img, num_strokes, canvas_height, canvas_width,
            length_scale, width_scale, init=init_mode,
        )

        # Reverse the last axis of every point array. stroke_renderer treats
        # column 0 as x (clamped against the canvas width).
        # NOTE(review): assumes initialize_brushstrokes returns (row, col)
        # pairs; confirm.
        location, s, e, c = (arr[..., ::-1].copy() for arr in (location, s, e, c))

        def as_param(arr):
            # Wrap a numpy array as a float32 learnable parameter.
            return torch.nn.Parameter(torch.from_numpy(arr.astype(np.float32)))

        # Registration order determines state_dict / parameters() ordering,
        # so it is kept fixed.
        self.curve_s = as_param(s)
        self.curve_e = as_param(e)
        self.curve_c = as_param(c)
        self.color = as_param(color)
        self.location = as_param(location)
        # Widths are optimized in log space. Floor at 1e-3 so log() stays finite.
        self.width = as_param(np.log(np.maximum(width, 1e-3)))

    def forward(self):
        """Sample every stroke's curve and render the [H, W, 3] canvas."""
        anchor = self.location
        curve_points = sample_quadratic_bezier_curve(
            s=self.curve_s + anchor,
            e=self.curve_e + anchor,
            c=self.curve_c + anchor,
            num_points=self.samples_per_curve,
        )
        return stroke_renderer(
            curve_points,
            anchor,
            self.color,
            self.width,
            self.canvas_height,
            self.canvas_width,
            self.strokes_per_pixel,
            self.canvas_color,
        )
|
|