Mxdua89 commited on
Commit
4528946
·
verified ·
1 Parent(s): b0c7be2

Upload visualizer_drag_gradio.py

Browse files
Files changed (1) hide show
  1. visualizer_drag_gradio.py +871 -0
visualizer_drag_gradio.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ from argparse import ArgumentParser
4
+ from functools import partial
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+
11
+ import dnnlib
12
+ from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
13
+ get_latest_points_pair, get_valid_mask,
14
+ on_change_single_global_state)
15
+ from viz.renderer import Renderer, add_watermark_np
16
+
17
+ parser = ArgumentParser()
18
+ parser.add_argument('--share', action='store_true',default='True')
19
+ parser.add_argument('--cache-dir', type=str, default='./checkpoints')
20
+ parser.add_argument(
21
+ "--listen",
22
+ action="store_true",
23
+ help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests",
24
+ )
25
+ args = parser.parse_args()
26
+
27
+ cache_dir = args.cache_dir
28
+
29
+ device = 'cuda'
30
+
31
+
32
+ def reverse_point_pairs(points):
33
+ new_points = []
34
+ for p in points:
35
+ new_points.append([p[1], p[0]])
36
+ return new_points
37
+
38
+
39
+ def clear_state(global_state, target=None):
40
+ """Clear target history state from global_state
41
+ If target is not defined, points and mask will be both removed.
42
+ 1. set global_state['points'] as empty dict
43
+ 2. set global_state['mask'] as full-one mask.
44
+ """
45
+ if target is None:
46
+ target = ['point', 'mask']
47
+ if not isinstance(target, list):
48
+ target = [target]
49
+ if 'point' in target:
50
+ global_state['points'] = dict()
51
+ print('Clear Points State!')
52
+ if 'mask' in target:
53
+ image_raw = global_state["images"]["image_raw"]
54
+ global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
55
+ dtype=np.uint8)
56
+ print('Clear mask State!')
57
+
58
+ return global_state
59
+
60
+
61
+ def init_images(global_state):
62
+ """This function is called only ones with Gradio App is started.
63
+ 0. pre-process global_state, unpack value from global_state of need
64
+ 1. Re-init renderer
65
+ 2. run `renderer._render_drag_impl` with `is_drag=False` to generate
66
+ new image
67
+ 3. Assign images to global state and re-generate mask
68
+ """
69
+
70
+ if isinstance(global_state, gr.State):
71
+ state = global_state.value
72
+ else:
73
+ state = global_state
74
+
75
+ state['renderer'].init_network(
76
+ state['generator_params'], # res
77
+ valid_checkpoints_dict[state['pretrained_weight']], # pkl
78
+ state['params']['seed'], # w0_seed,
79
+ None, # w_load
80
+ state['params']['latent_space'] == 'w+', # w_plus
81
+ 'const',
82
+ state['params']['trunc_psi'], # trunc_psi,
83
+ state['params']['trunc_cutoff'], # trunc_cutoff,
84
+ None, # input_transform
85
+ state['params']['lr'] # lr,
86
+ )
87
+
88
+ state['renderer']._render_drag_impl(state['generator_params'],
89
+ is_drag=False,
90
+ to_pil=True)
91
+
92
+ init_image = state['generator_params'].image
93
+ state['images']['image_orig'] = init_image
94
+ state['images']['image_raw'] = init_image
95
+ state['images']['image_show'] = Image.fromarray(
96
+ add_watermark_np(np.array(init_image)))
97
+ state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
98
+ dtype=np.uint8)
99
+ return global_state
100
+
101
+
102
+ def update_image_draw(image, points, mask, show_mask, global_state=None):
103
+
104
+ image_draw = draw_points_on_image(image, points)
105
+ if show_mask and mask is not None and not (mask == 0).all() and not (
106
+ mask == 1).all():
107
+ image_draw = draw_mask_on_image(image_draw, mask)
108
+
109
+ image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
110
+ if global_state is not None:
111
+ global_state['images']['image_show'] = image_draw
112
+ return image_draw
113
+
114
+
115
+ def preprocess_mask_info(global_state, image):
116
+ """Function to handle mask information.
117
+ 1. last_mask is None: Do not need to change mask, return mask
118
+ 2. last_mask is not None:
119
+ 2.1 global_state is remove_mask:
120
+ 2.2 global_state is add_mask:
121
+ """
122
+ if isinstance(image, dict):
123
+ last_mask = get_valid_mask(image['mask'])
124
+ else:
125
+ last_mask = None
126
+ mask = global_state['mask']
127
+
128
+ # mask in global state is a placeholder with all 1.
129
+ if (mask == 1).all():
130
+ mask = last_mask
131
+
132
+ # last_mask = global_state['last_mask']
133
+ editing_mode = global_state['editing_state']
134
+
135
+ if last_mask is None:
136
+ return global_state
137
+
138
+ if editing_mode == 'remove_mask':
139
+ updated_mask = np.clip(mask - last_mask, 0, 1)
140
+ print(f'Last editing_state is {editing_mode}, do remove.')
141
+ elif editing_mode == 'add_mask':
142
+ updated_mask = np.clip(mask + last_mask, 0, 1)
143
+ print(f'Last editing_state is {editing_mode}, do add.')
144
+ else:
145
+ updated_mask = mask
146
+ print(f'Last editing_state is {editing_mode}, '
147
+ 'do nothing to mask.')
148
+
149
+ global_state['mask'] = updated_mask
150
+ # global_state['last_mask'] = None # clear buffer
151
+ return global_state
152
+
153
+
154
+ valid_checkpoints_dict = {
155
+ f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
156
+ for f in os.listdir(cache_dir)
157
+ if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
158
+ }
159
+ print(f'File under cache_dir ({cache_dir}):')
160
+ print(os.listdir(cache_dir))
161
+ print('Valid checkpoint file:')
162
+ print(valid_checkpoints_dict)
163
+
164
+ init_pkl = 'stylegan2_lions_512_pytorch'
165
+
166
+ with gr.Blocks() as app:
167
+
168
+ # renderer = Renderer()
169
+ global_state = gr.State({
170
+ "images": {
171
+ # image_orig: the original image, change with seed/model is changed
172
+ # image_raw: image with mask and points, change durning optimization
173
+ # image_show: image showed on screen
174
+ },
175
+ "temporal_params": {
176
+ # stop
177
+ },
178
+ 'mask':
179
+ None, # mask for visualization, 1 for editing and 0 for unchange
180
+ 'last_mask': None, # last edited mask
181
+ 'show_mask': True, # add button
182
+ "generator_params": dnnlib.EasyDict(),
183
+ "params": {
184
+ "seed": 0,
185
+ "motion_lambda": 20,
186
+ "r1_in_pixels": 3,
187
+ "r2_in_pixels": 12,
188
+ "magnitude_direction_in_pixels": 1.0,
189
+ "latent_space": "w+",
190
+ "trunc_psi": 0.7,
191
+ "trunc_cutoff": None,
192
+ "lr": 0.001,
193
+ },
194
+ "device": device,
195
+ "draw_interval": 1,
196
+ "renderer": Renderer(disable_timing=True),
197
+ "points": {},
198
+ "curr_point": None,
199
+ "curr_type_point": "start",
200
+ 'editing_state': 'add_points',
201
+ 'pretrained_weight': init_pkl
202
+ })
203
+
204
+ # init image
205
+ global_state = init_images(global_state)
206
+
207
+ with gr.Row():
208
+
209
+ with gr.Row():
210
+
211
+ # Left --> tools
212
+ with gr.Column(scale=3):
213
+
214
+ # Pickle
215
+ with gr.Row():
216
+
217
+ with gr.Column(scale=1, min_width=10):
218
+ gr.Markdown(value='Pickle', show_label=False)
219
+
220
+ with gr.Column(scale=4, min_width=10):
221
+ form_pretrained_dropdown = gr.Dropdown(
222
+ choices=list(valid_checkpoints_dict.keys()),
223
+ label="Pretrained Model",
224
+ value=init_pkl,
225
+ )
226
+
227
+ # Latent
228
+ with gr.Row():
229
+ with gr.Column(scale=1, min_width=10):
230
+ gr.Markdown(value='Latent', show_label=False)
231
+
232
+ with gr.Column(scale=4, min_width=10):
233
+ form_seed_number = gr.Number(
234
+ value=global_state.value['params']['seed'],
235
+ interactive=True,
236
+ label="Seed",
237
+ )
238
+ form_lr_number = gr.Number(
239
+ value=global_state.value["params"]["lr"],
240
+ interactive=True,
241
+ label="Step Size")
242
+
243
+ with gr.Row():
244
+ with gr.Column(scale=2, min_width=10):
245
+ form_reset_image = gr.Button("Reset Image")
246
+ with gr.Column(scale=3, min_width=10):
247
+ form_latent_space = gr.Radio(
248
+ ['w', 'w+'],
249
+ value=global_state.value['params']
250
+ ['latent_space'],
251
+ interactive=True,
252
+ label='Latent space to optimize',
253
+ show_label=False,
254
+ )
255
+
256
+ # Drag
257
+ with gr.Row():
258
+ with gr.Column(scale=1, min_width=10):
259
+ gr.Markdown(value='Drag', show_label=False)
260
+ with gr.Column(scale=4, min_width=10):
261
+ with gr.Row():
262
+ with gr.Column(scale=1, min_width=10):
263
+ enable_add_points = gr.Button('Add Points')
264
+ with gr.Column(scale=1, min_width=10):
265
+ undo_points = gr.Button('Reset Points')
266
+ with gr.Row():
267
+ with gr.Column(scale=1, min_width=10):
268
+ form_start_btn = gr.Button("Start")
269
+ with gr.Column(scale=1, min_width=10):
270
+ form_stop_btn = gr.Button("Stop")
271
+
272
+ form_steps_number = gr.Number(value=0,
273
+ label="Steps",
274
+ interactive=False)
275
+
276
+ # Mask
277
+ with gr.Row():
278
+ with gr.Column(scale=1, min_width=10):
279
+ gr.Markdown(value='Mask', show_label=False)
280
+ with gr.Column(scale=4, min_width=10):
281
+ enable_add_mask = gr.Button('Edit Flexible Area')
282
+ with gr.Row():
283
+ with gr.Column(scale=1, min_width=10):
284
+ form_reset_mask_btn = gr.Button("Reset mask")
285
+ with gr.Column(scale=1, min_width=10):
286
+ show_mask = gr.Checkbox(
287
+ label='Show Mask',
288
+ value=global_state.value['show_mask'],
289
+ show_label=False)
290
+
291
+ with gr.Row():
292
+ form_lambda_number = gr.Number(
293
+ value=global_state.value["params"]
294
+ ["motion_lambda"],
295
+ interactive=True,
296
+ label="Lambda",
297
+ )
298
+
299
+ form_draw_interval_number = gr.Number(
300
+ value=global_state.value["draw_interval"],
301
+ label="Draw Interval (steps)",
302
+ interactive=True,
303
+ visible=False)
304
+
305
+ # Right --> Image
306
+ with gr.Column(scale=8):
307
+ form_image = ImageMask(
308
+ value=global_state.value['images']['image_show'],
309
+ brush_radius=20).style(
310
+ width=768,
311
+ height=768) # NOTE: hard image size code here.
312
+ gr.Markdown("""
313
+ ## Quick Start
314
+
315
+ 1. Select desired `Pretrained Model` and adjust `Seed` to generate an
316
+ initial image.
317
+ 2. Click on image to add control points.
318
+ 3. Click `Start` and enjoy it!
319
+
320
+ ## Advance Usage
321
+
322
+ 1. Change `Step Size` to adjust learning rate in drag optimization.
323
+ 2. Select `w` or `w+` to change latent space to optimize:
324
+ * Optimize on `w` space may cause greater influence to the image.
325
+ * Optimize on `w+` space may work slower than `w`, but usually achieve
326
+ better results.
327
+ * Note that changing the latent space will reset the image, points and
328
+ mask (this has the same effect as `Reset Image` button).
329
+ 3. Click `Edit Flexible Area` to create a mask and constrain the
330
+ unmasked region to remain unchanged.
331
+ """)
332
+ gr.HTML("""
333
+ <style>
334
+ .container {
335
+ position: absolute;
336
+ height: 50px;
337
+ text-align: center;
338
+ line-height: 50px;
339
+ width: 100%;
340
+ }
341
+ </style>
342
+ <div class="container">
343
+ Gradio demo supported by
344
+ <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
345
+ <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
346
+ </div>
347
+ """)
348
+
349
+ # Network & latents tab listeners
350
+ def on_change_pretrained_dropdown(pretrained_value, global_state):
351
+ """Function to handle model change.
352
+ 1. Set pretrained value to global_state
353
+ 2. Re-init images and clear all states
354
+ """
355
+
356
+ global_state['pretrained_weight'] = pretrained_value
357
+ init_images(global_state)
358
+ clear_state(global_state)
359
+
360
+ return global_state, global_state["images"]['image_show']
361
+
362
+ form_pretrained_dropdown.change(
363
+ on_change_pretrained_dropdown,
364
+ inputs=[form_pretrained_dropdown, global_state],
365
+ outputs=[global_state, form_image],
366
+ )
367
+
368
+ def on_click_reset_image(global_state):
369
+ """Reset image to the original one and clear all states
370
+ 1. Re-init images
371
+ 2. Clear all states
372
+ """
373
+
374
+ init_images(global_state)
375
+ clear_state(global_state)
376
+
377
+ return global_state, global_state['images']['image_show']
378
+
379
+ form_reset_image.click(
380
+ on_click_reset_image,
381
+ inputs=[global_state],
382
+ outputs=[global_state, form_image],
383
+ )
384
+
385
+ # Update parameters
386
+ def on_change_update_image_seed(seed, global_state):
387
+ """Function to handle generation seed change.
388
+ 1. Set seed to global_state
389
+ 2. Re-init images and clear all states
390
+ """
391
+
392
+ global_state["params"]["seed"] = int(seed)
393
+ init_images(global_state)
394
+ clear_state(global_state)
395
+
396
+ return global_state, global_state['images']['image_show']
397
+
398
+ form_seed_number.change(
399
+ on_change_update_image_seed,
400
+ inputs=[form_seed_number, global_state],
401
+ outputs=[global_state, form_image],
402
+ )
403
+
404
+ def on_click_latent_space(latent_space, global_state):
405
+ """Function to reset latent space to optimize.
406
+ NOTE: this function we reset the image and all controls
407
+ 1. Set latent-space to global_state
408
+ 2. Re-init images and clear all state
409
+ """
410
+
411
+ global_state['params']['latent_space'] = latent_space
412
+ init_images(global_state)
413
+ clear_state(global_state)
414
+
415
+ return global_state, global_state['images']['image_show']
416
+
417
+ form_latent_space.change(on_click_latent_space,
418
+ inputs=[form_latent_space, global_state],
419
+ outputs=[global_state, form_image])
420
+
421
+ # ==== Params
422
+ form_lambda_number.change(
423
+ partial(on_change_single_global_state, ["params", "motion_lambda"]),
424
+ inputs=[form_lambda_number, global_state],
425
+ outputs=[global_state],
426
+ )
427
+
428
+ def on_change_lr(lr, global_state):
429
+ if lr == 0:
430
+ print('lr is 0, do nothing.')
431
+ return global_state
432
+ else:
433
+ global_state["params"]["lr"] = lr
434
+ renderer = global_state['renderer']
435
+ renderer.update_lr(lr)
436
+ print('New optimizer: ')
437
+ print(renderer.w_optim)
438
+ return global_state
439
+
440
+ form_lr_number.change(
441
+ on_change_lr,
442
+ inputs=[form_lr_number, global_state],
443
+ outputs=[global_state],
444
+ )
445
+
446
+ def on_click_start(global_state, image):
447
+ p_in_pixels = []
448
+ t_in_pixels = []
449
+ valid_points = []
450
+
451
+ # handle of start drag in mask editing mode
452
+ global_state = preprocess_mask_info(global_state, image)
453
+
454
+ # Prepare the points for the inference
455
+ if len(global_state["points"]) == 0:
456
+ # yield on_click_start_wo_points(global_state, image)
457
+ image_raw = global_state['images']['image_raw']
458
+ update_image_draw(
459
+ image_raw,
460
+ global_state['points'],
461
+ global_state['mask'],
462
+ global_state['show_mask'],
463
+ global_state,
464
+ )
465
+
466
+ yield (
467
+ global_state,
468
+ 0,
469
+ global_state['images']['image_show'],
470
+ # gr.File.update(visible=False),
471
+ gr.Button.update(interactive=True),
472
+ gr.Button.update(interactive=True),
473
+ gr.Button.update(interactive=True),
474
+ gr.Button.update(interactive=True),
475
+ gr.Button.update(interactive=True),
476
+ # latent space
477
+ gr.Radio.update(interactive=True),
478
+ gr.Button.update(interactive=True),
479
+ # NOTE: disable stop button
480
+ gr.Button.update(interactive=False),
481
+
482
+ # update other comps
483
+ gr.Dropdown.update(interactive=True),
484
+ gr.Number.update(interactive=True),
485
+ gr.Number.update(interactive=True),
486
+ gr.Button.update(interactive=True),
487
+ gr.Button.update(interactive=True),
488
+ gr.Checkbox.update(interactive=True),
489
+ # gr.Number.update(interactive=True),
490
+ gr.Number.update(interactive=True),
491
+ )
492
+ else:
493
+
494
+ # Transform the points into torch tensors
495
+ for key_point, point in global_state["points"].items():
496
+ try:
497
+ p_start = point.get("start_temp", point["start"])
498
+ p_end = point["target"]
499
+
500
+ if p_start is None or p_end is None:
501
+ continue
502
+
503
+ except KeyError:
504
+ continue
505
+
506
+ p_in_pixels.append(p_start)
507
+ t_in_pixels.append(p_end)
508
+ valid_points.append(key_point)
509
+
510
+ mask = torch.tensor(global_state['mask']).float()
511
+ drag_mask = 1 - mask
512
+
513
+ renderer: Renderer = global_state["renderer"]
514
+ global_state['temporal_params']['stop'] = False
515
+ global_state['editing_state'] = 'running'
516
+
517
+ # reverse points order
518
+ p_to_opt = reverse_point_pairs(p_in_pixels)
519
+ t_to_opt = reverse_point_pairs(t_in_pixels)
520
+ print('Running with:')
521
+ print(f' Source: {p_in_pixels}')
522
+ print(f' Target: {t_in_pixels}')
523
+ step_idx = 0
524
+ while True:
525
+ if global_state["temporal_params"]["stop"]:
526
+ break
527
+
528
+ # do drage here!
529
+ renderer._render_drag_impl(
530
+ global_state['generator_params'],
531
+ p_to_opt, # point
532
+ t_to_opt, # target
533
+ drag_mask, # mask,
534
+ global_state['params']['motion_lambda'], # lambda_mask
535
+ reg=0,
536
+ feature_idx=5, # NOTE: do not support change for now
537
+ r1=global_state['params']['r1_in_pixels'], # r1
538
+ r2=global_state['params']['r2_in_pixels'], # r2
539
+ # random_seed = 0,
540
+ # noise_mode = 'const',
541
+ trunc_psi=global_state['params']['trunc_psi'],
542
+ # force_fp32 = False,
543
+ # layer_name = None,
544
+ # sel_channels = 3,
545
+ # base_channel = 0,
546
+ # img_scale_db = 0,
547
+ # img_normalize = False,
548
+ # untransform = False,
549
+ is_drag=True,
550
+ to_pil=True)
551
+
552
+ if step_idx % global_state['draw_interval'] == 0:
553
+ print('Current Source:')
554
+ for key_point, p_i, t_i in zip(valid_points, p_to_opt,
555
+ t_to_opt):
556
+ global_state["points"][key_point]["start_temp"] = [
557
+ p_i[1],
558
+ p_i[0],
559
+ ]
560
+ global_state["points"][key_point]["target"] = [
561
+ t_i[1],
562
+ t_i[0],
563
+ ]
564
+ start_temp = global_state["points"][key_point][
565
+ "start_temp"]
566
+ print(f' {start_temp}')
567
+
568
+ image_result = global_state['generator_params']['image']
569
+ image_draw = update_image_draw(
570
+ image_result,
571
+ global_state['points'],
572
+ global_state['mask'],
573
+ global_state['show_mask'],
574
+ global_state,
575
+ )
576
+ global_state['images']['image_raw'] = image_result
577
+
578
+ yield (
579
+ global_state,
580
+ step_idx,
581
+ global_state['images']['image_show'],
582
+ # gr.File.update(visible=False),
583
+ gr.Button.update(interactive=False),
584
+ gr.Button.update(interactive=False),
585
+ gr.Button.update(interactive=False),
586
+ gr.Button.update(interactive=False),
587
+ gr.Button.update(interactive=False),
588
+ # latent space
589
+ gr.Radio.update(interactive=False),
590
+ gr.Button.update(interactive=False),
591
+ # enable stop button in loop
592
+ gr.Button.update(interactive=True),
593
+
594
+ # update other comps
595
+ gr.Dropdown.update(interactive=False),
596
+ gr.Number.update(interactive=False),
597
+ gr.Number.update(interactive=False),
598
+ gr.Button.update(interactive=False),
599
+ gr.Button.update(interactive=False),
600
+ gr.Checkbox.update(interactive=False),
601
+ # gr.Number.update(interactive=False),
602
+ gr.Number.update(interactive=False),
603
+ )
604
+
605
+ # increate step
606
+ step_idx += 1
607
+
608
+ image_result = global_state['generator_params']['image']
609
+ global_state['images']['image_raw'] = image_result
610
+ image_draw = update_image_draw(image_result,
611
+ global_state['points'],
612
+ global_state['mask'],
613
+ global_state['show_mask'],
614
+ global_state)
615
+
616
+ # fp = NamedTemporaryFile(suffix=".png", delete=False)
617
+ # image_result.save(fp, "PNG")
618
+
619
+ global_state['editing_state'] = 'add_points'
620
+
621
+ yield (
622
+ global_state,
623
+ 0, # reset step to 0 after stop.
624
+ global_state['images']['image_show'],
625
+ # gr.File.update(visible=True, value=fp.name),
626
+ gr.Button.update(interactive=True),
627
+ gr.Button.update(interactive=True),
628
+ gr.Button.update(interactive=True),
629
+ gr.Button.update(interactive=True),
630
+ gr.Button.update(interactive=True),
631
+ # latent space
632
+ gr.Radio.update(interactive=True),
633
+ gr.Button.update(interactive=True),
634
+ # NOTE: disable stop button with loop finish
635
+ gr.Button.update(interactive=False),
636
+
637
+ # update other comps
638
+ gr.Dropdown.update(interactive=True),
639
+ gr.Number.update(interactive=True),
640
+ gr.Number.update(interactive=True),
641
+ gr.Checkbox.update(interactive=True),
642
+ gr.Number.update(interactive=True),
643
+ )
644
+
645
+ form_start_btn.click(
646
+ on_click_start,
647
+ inputs=[global_state, form_image],
648
+ outputs=[
649
+ global_state,
650
+ form_steps_number,
651
+ form_image,
652
+ # form_download_result_file,
653
+ # >>> buttons
654
+ form_reset_image,
655
+ enable_add_points,
656
+ enable_add_mask,
657
+ undo_points,
658
+ form_reset_mask_btn,
659
+ form_latent_space,
660
+ form_start_btn,
661
+ form_stop_btn,
662
+ # <<< buttonm
663
+ # >>> inputs comps
664
+ form_pretrained_dropdown,
665
+ form_seed_number,
666
+ form_lr_number,
667
+ show_mask,
668
+ form_lambda_number,
669
+ ],
670
+ )
671
+
672
+ def on_click_stop(global_state):
673
+ """Function to handle stop button is clicked.
674
+ 1. send a stop signal by set global_state["temporal_params"]["stop"] as True
675
+ 2. Disable Stop button
676
+ """
677
+ global_state["temporal_params"]["stop"] = True
678
+
679
+ return global_state, gr.Button.update(interactive=False)
680
+
681
+ form_stop_btn.click(on_click_stop,
682
+ inputs=[global_state],
683
+ outputs=[global_state, form_stop_btn])
684
+
685
+ form_draw_interval_number.change(
686
+ partial(
687
+ on_change_single_global_state,
688
+ "draw_interval",
689
+ map_transform=lambda x: int(x),
690
+ ),
691
+ inputs=[form_draw_interval_number, global_state],
692
+ outputs=[global_state],
693
+ )
694
+
695
+ def on_click_remove_point(global_state):
696
+ choice = global_state["curr_point"]
697
+ del global_state["points"][choice]
698
+
699
+ choices = list(global_state["points"].keys())
700
+
701
+ if len(choices) > 0:
702
+ global_state["curr_point"] = choices[0]
703
+
704
+ return (
705
+ gr.Dropdown.update(choices=choices, value=choices[0]),
706
+ global_state,
707
+ )
708
+
709
+ # Mask
710
+ def on_click_reset_mask(global_state):
711
+ global_state['mask'] = np.ones(
712
+ (
713
+ global_state["images"]["image_raw"].size[1],
714
+ global_state["images"]["image_raw"].size[0],
715
+ ),
716
+ dtype=np.uint8,
717
+ )
718
+ image_draw = update_image_draw(global_state['images']['image_raw'],
719
+ global_state['points'],
720
+ global_state['mask'],
721
+ global_state['show_mask'], global_state)
722
+ return global_state, image_draw
723
+
724
+ form_reset_mask_btn.click(
725
+ on_click_reset_mask,
726
+ inputs=[global_state],
727
+ outputs=[global_state, form_image],
728
+ )
729
+
730
+ # Image
731
+ def on_click_enable_draw(global_state, image):
732
+ """Function to start add mask mode.
733
+ 1. Preprocess mask info from last state
734
+ 2. Change editing state to add_mask
735
+ 3. Set curr image with points and mask
736
+ """
737
+ global_state = preprocess_mask_info(global_state, image)
738
+ global_state['editing_state'] = 'add_mask'
739
+ image_raw = global_state['images']['image_raw']
740
+ image_draw = update_image_draw(image_raw, global_state['points'],
741
+ global_state['mask'], True,
742
+ global_state)
743
+ return (global_state,
744
+ gr.Image.update(value=image_draw, interactive=True))
745
+
746
+ def on_click_remove_draw(global_state, image):
747
+ """Function to start remove mask mode.
748
+ 1. Preprocess mask info from last state
749
+ 2. Change editing state to remove_mask
750
+ 3. Set curr image with points and mask
751
+ """
752
+ global_state = preprocess_mask_info(global_state, image)
753
+ global_state['edinting_state'] = 'remove_mask'
754
+ image_raw = global_state['images']['image_raw']
755
+ image_draw = update_image_draw(image_raw, global_state['points'],
756
+ global_state['mask'], True,
757
+ global_state)
758
+ return (global_state,
759
+ gr.Image.update(value=image_draw, interactive=True))
760
+
761
+ enable_add_mask.click(on_click_enable_draw,
762
+ inputs=[global_state, form_image],
763
+ outputs=[
764
+ global_state,
765
+ form_image,
766
+ ])
767
+
768
+ def on_click_add_point(global_state, image: dict):
769
+ """Function switch from add mask mode to add points mode.
770
+ 1. Updaste mask buffer if need
771
+ 2. Change global_state['editing_state'] to 'add_points'
772
+ 3. Set current image with mask
773
+ """
774
+
775
+ global_state = preprocess_mask_info(global_state, image)
776
+ global_state['editing_state'] = 'add_points'
777
+ mask = global_state['mask']
778
+ image_raw = global_state['images']['image_raw']
779
+ image_draw = update_image_draw(image_raw, global_state['points'], mask,
780
+ global_state['show_mask'], global_state)
781
+
782
+ return (global_state,
783
+ gr.Image.update(value=image_draw, interactive=False))
784
+
785
+ enable_add_points.click(on_click_add_point,
786
+ inputs=[global_state, form_image],
787
+ outputs=[global_state, form_image])
788
+
789
+ def on_click_image(global_state, evt: gr.SelectData):
790
+ """This function only support click for point selection
791
+ """
792
+ xy = evt.index
793
+ if global_state['editing_state'] != 'add_points':
794
+ print(f'In {global_state["editing_state"]} state. '
795
+ 'Do not add points.')
796
+
797
+ return global_state, global_state['images']['image_show']
798
+
799
+ points = global_state["points"]
800
+
801
+ point_idx = get_latest_points_pair(points)
802
+ if point_idx is None:
803
+ points[0] = {'start': xy, 'target': None}
804
+ print(f'Click Image - Start - {xy}')
805
+ elif points[point_idx].get('target', None) is None:
806
+ points[point_idx]['target'] = xy
807
+ print(f'Click Image - Target - {xy}')
808
+ else:
809
+ points[point_idx + 1] = {'start': xy, 'target': None}
810
+ print(f'Click Image - Start - {xy}')
811
+
812
+ image_raw = global_state['images']['image_raw']
813
+ image_draw = update_image_draw(
814
+ image_raw,
815
+ global_state['points'],
816
+ global_state['mask'],
817
+ global_state['show_mask'],
818
+ global_state,
819
+ )
820
+
821
+ return global_state, image_draw
822
+
823
+ form_image.select(
824
+ on_click_image,
825
+ inputs=[global_state],
826
+ outputs=[global_state, form_image],
827
+ )
828
+
829
+ def on_click_clear_points(global_state):
830
+ """Function to handle clear all control points
831
+ 1. clear global_state['points'] (clear_state)
832
+ 2. re-init network
833
+ 2. re-draw image
834
+ """
835
+ clear_state(global_state, target='point')
836
+
837
+ renderer: Renderer = global_state["renderer"]
838
+ renderer.feat_refs = None
839
+
840
+ image_raw = global_state['images']['image_raw']
841
+ image_draw = update_image_draw(image_raw, {}, global_state['mask'],
842
+ global_state['show_mask'], global_state)
843
+ return global_state, image_draw
844
+
845
+ undo_points.click(on_click_clear_points,
846
+ inputs=[global_state],
847
+ outputs=[global_state, form_image])
848
+
849
+ def on_click_show_mask(global_state, show_mask):
850
+ """Function to control whether show mask on image."""
851
+ global_state['show_mask'] = show_mask
852
+
853
+ image_raw = global_state['images']['image_raw']
854
+ image_draw = update_image_draw(
855
+ image_raw,
856
+ global_state['points'],
857
+ global_state['mask'],
858
+ global_state['show_mask'],
859
+ global_state,
860
+ )
861
+ return global_state, image_draw
862
+
863
+ show_mask.change(
864
+ on_click_show_mask,
865
+ inputs=[global_state, show_mask],
866
+ outputs=[global_state, form_image],
867
+ )
868
+
869
+ gr.close_all()
870
+ app.queue(concurrency_count=3, max_size=20)
871
+ app.launch(share=args.share, server_name="0.0.0.0" if args.listen else "127.0.0.1")