sdgsdggds commited on
Commit
d326ccc
·
1 Parent(s): 95d1ac8

Upload scripts.py

Browse files
Files changed (1) hide show
  1. scripts.py +594 -0
scripts.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ import traceback
5
+ from collections import namedtuple
6
+
7
+ import gradio as gr
8
+
9
+ from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
10
+
11
+ AlwaysVisible = object()
12
+
13
+
14
+ class PostprocessImageArgs:
15
+ def __init__(self, image):
16
+ self.image = image
17
+
18
+
19
+ class Script:
20
+ name = None
21
+ """script's internal name derived from title"""
22
+
23
+ filename = None
24
+ args_from = None
25
+ args_to = None
26
+ alwayson = False
27
+
28
+ is_txt2img = False
29
+ is_img2img = False
30
+
31
+ group = None
32
+ """A gr.Group component that has all script's UI inside it"""
33
+
34
+ infotext_fields = None
35
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
36
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
37
+ """
38
+
39
+ paste_field_names = None
40
+ """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
41
+ various "Send to <X>" buttons when clicked
42
+ """
43
+
44
+ api_info = None
45
+ """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
46
+
47
+ def title(self):
48
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
49
+
50
+ raise NotImplementedError()
51
+
52
+ def ui(self, is_img2img):
53
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
54
+ The return value should be an array of all components that are used in processing.
55
+ Values of those returned components will be passed to run() and process() functions.
56
+ """
57
+
58
+ pass
59
+
60
+ def show(self, is_img2img):
61
+ """
62
+ is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
63
+
64
+ This function should return:
65
+ - False if the script should not be shown in UI at all
66
+ - True if the script should be shown in UI if it's selected in the scripts dropdown
67
+ - script.AlwaysVisible if the script should be shown in UI at all times
68
+ """
69
+
70
+ return True
71
+
72
+ def run(self, p, *args):
73
+ """
74
+ This function is called if the script has been selected in the script dropdown.
75
+ It must do all processing and return the Processed object with results, same as
76
+ one returned by processing.process_images.
77
+
78
+ Usually the processing is done by calling the processing.process_images function.
79
+
80
+ args contains all values returned by components from ui()
81
+ """
82
+
83
+ pass
84
+
85
+ def process(self, p, *args):
86
+ """
87
+ This function is called before processing begins for AlwaysVisible scripts.
88
+ You can modify the processing object (p) here, inject hooks, etc.
89
+ args contains all values returned by components from ui()
90
+ """
91
+
92
+ pass
93
+
94
+ def before_process_batch(self, p, *args, **kwargs):
95
+ """
96
+ Called before extra networks are parsed from the prompt, so you can add
97
+ new extra network keywords to the prompt with this callback.
98
+
99
+ **kwargs will have those items:
100
+ - batch_number - index of current batch, from 0 to number of batches-1
101
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
102
+ - seeds - list of seeds for current batch
103
+ - subseeds - list of subseeds for current batch
104
+ """
105
+
106
+ pass
107
+
108
+ def process_batch(self, p, *args, **kwargs):
109
+ """
110
+ Same as process(), but called for every batch.
111
+
112
+ **kwargs will have those items:
113
+ - batch_number - index of current batch, from 0 to number of batches-1
114
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
115
+ - seeds - list of seeds for current batch
116
+ - subseeds - list of subseeds for current batch
117
+ """
118
+
119
+ pass
120
+
121
+ def postprocess_batch(self, p, *args, **kwargs):
122
+ """
123
+ Same as process_batch(), but called for every batch after it has been generated.
124
+
125
+ **kwargs will have same items as process_batch, and also:
126
+ - batch_number - index of current batch, from 0 to number of batches-1
127
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
128
+ """
129
+
130
+ pass
131
+
132
+ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
133
+ """
134
+ Called for every image after it has been generated.
135
+ """
136
+
137
+ pass
138
+
139
+ def postprocess(self, p, processed, *args):
140
+ """
141
+ This function is called after processing ends for AlwaysVisible scripts.
142
+ args contains all values returned by components from ui()
143
+ """
144
+
145
+ pass
146
+
147
+ def before_component(self, component, **kwargs):
148
+ """
149
+ Called before a component is created.
150
+ Use elem_id/label fields of kwargs to figure out which component it is.
151
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
152
+ You can return created components in the ui() function to add them to the list of arguments for your processing functions
153
+ """
154
+
155
+ pass
156
+
157
+ def after_component(self, component, **kwargs):
158
+ """
159
+ Called after a component is created. Same as above.
160
+ """
161
+
162
+ pass
163
+
164
+ def describe(self):
165
+ """unused"""
166
+ return ""
167
+
168
+ def elem_id(self, item_id):
169
+ """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
170
+
171
+ need_tabname = self.show(True) == self.show(False)
172
+ tabkind = 'img2img' if self.is_img2img else 'txt2txt'
173
+ tabname = f"{tabkind}_" if need_tabname else ""
174
+ title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
175
+
176
+ return f'script_{tabname}{title}_{item_id}'
177
+
178
+
179
+ current_basedir = paths.script_path
180
+
181
+
182
+ def basedir():
183
+ """returns the base directory for the current script. For scripts in the main scripts directory,
184
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
185
+ (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
186
+ """
187
+ return current_basedir
188
+
189
+
190
+ ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
191
+
192
+ scripts_data = []
193
+ postprocessing_scripts_data = []
194
+ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
195
+
196
+
197
+ def list_scripts(scriptdirname, extension):
198
+ scripts_list = []
199
+
200
+ basedir = os.path.join(paths.script_path, scriptdirname)
201
+ if os.path.exists(basedir):
202
+ for filename in sorted(os.listdir(basedir)):
203
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
204
+
205
+ for ext in extensions.active():
206
+ scripts_list += ext.list_files(scriptdirname, extension)
207
+
208
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
209
+
210
+ return scripts_list
211
+
212
+
213
+ def list_files_with_name(filename):
214
+ res = []
215
+
216
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
217
+
218
+ for dirpath in dirs:
219
+ if not os.path.isdir(dirpath):
220
+ continue
221
+
222
+ path = os.path.join(dirpath, filename)
223
+ if os.path.isfile(path):
224
+ res.append(path)
225
+
226
+ return res
227
+
228
+
229
+ def load_scripts():
230
+ global current_basedir
231
+ scripts_data.clear()
232
+ postprocessing_scripts_data.clear()
233
+ script_callbacks.clear_callbacks()
234
+
235
+ scripts_list = list_scripts("scripts", ".py")
236
+
237
+ syspath = sys.path
238
+
239
+ def register_scripts_from_module(module):
240
+ for script_class in module.__dict__.values():
241
+ if type(script_class) != type:
242
+ continue
243
+
244
+ if issubclass(script_class, Script):
245
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
246
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
247
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
248
+
249
+ def orderby(basedir):
250
+ # 1st webui, 2nd extensions-builtin, 3rd extensions
251
+ priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
252
+ for key in priority:
253
+ if basedir.startswith(key):
254
+ return priority[key]
255
+ return 9999
256
+
257
+ for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
258
+ try:
259
+ if scriptfile.basedir != paths.script_path:
260
+ sys.path = [scriptfile.basedir] + sys.path
261
+ current_basedir = scriptfile.basedir
262
+
263
+ script_module = script_loading.load_module(scriptfile.path)
264
+ register_scripts_from_module(script_module)
265
+
266
+ except Exception:
267
+ print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
268
+ print(traceback.format_exc(), file=sys.stderr)
269
+
270
+ finally:
271
+ sys.path = syspath
272
+ current_basedir = paths.script_path
273
+
274
+ global scripts_txt2img, scripts_img2img, scripts_postproc
275
+
276
+ scripts_txt2img = ScriptRunner()
277
+ scripts_img2img = ScriptRunner()
278
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
279
+
280
+
281
+ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
282
+ try:
283
+ res = func(*args, **kwargs)
284
+ return res
285
+ except Exception:
286
+ print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
287
+ print(traceback.format_exc(), file=sys.stderr)
288
+
289
+ return default
290
+
291
+
292
+ class ScriptRunner:
293
+ def __init__(self):
294
+ self.scripts = []
295
+ self.selectable_scripts = []
296
+ self.alwayson_scripts = []
297
+ self.titles = []
298
+ self.infotext_fields = []
299
+ self.paste_field_names = []
300
+
301
+ def initialize_scripts(self, is_img2img):
302
+ from modules import scripts_auto_postprocessing
303
+
304
+ self.scripts.clear()
305
+ self.alwayson_scripts.clear()
306
+ self.selectable_scripts.clear()
307
+
308
+ auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
309
+
310
+ for script_data in auto_processing_scripts + scripts_data:
311
+ script = script_data.script_class()
312
+ script.filename = script_data.path
313
+ script.is_txt2img = not is_img2img
314
+ script.is_img2img = is_img2img
315
+
316
+ visibility = script.show(script.is_img2img)
317
+
318
+ if visibility == AlwaysVisible:
319
+ self.scripts.append(script)
320
+ self.alwayson_scripts.append(script)
321
+ script.alwayson = True
322
+
323
+ elif visibility:
324
+ self.scripts.append(script)
325
+ self.selectable_scripts.append(script)
326
+
327
+ def setup_ui(self):
328
+ import modules.api.models as api_models
329
+
330
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
331
+
332
+ inputs = [None]
333
+ inputs_alwayson = [True]
334
+
335
+ def create_script_ui(script, inputs, inputs_alwayson):
336
+ script.args_from = len(inputs)
337
+ script.args_to = len(inputs)
338
+
339
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
340
+
341
+ if controls is None:
342
+ return
343
+
344
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
345
+ api_args = []
346
+
347
+ for control in controls:
348
+ control.custom_script_source = os.path.basename(script.filename)
349
+
350
+ arg_info = api_models.ScriptArg(label=control.label or "")
351
+
352
+ for field in ("value", "minimum", "maximum", "step", "choices"):
353
+ v = getattr(control, field, None)
354
+ if v is not None:
355
+ setattr(arg_info, field, v)
356
+
357
+ api_args.append(arg_info)
358
+
359
+ script.api_info = api_models.ScriptInfo(
360
+ name=script.name,
361
+ is_img2img=script.is_img2img,
362
+ is_alwayson=script.alwayson,
363
+ args=api_args,
364
+ )
365
+
366
+ if script.infotext_fields is not None:
367
+ self.infotext_fields += script.infotext_fields
368
+
369
+ if script.paste_field_names is not None:
370
+ self.paste_field_names += script.paste_field_names
371
+
372
+ inputs += controls
373
+ inputs_alwayson += [script.alwayson for _ in controls]
374
+ script.args_to = len(inputs)
375
+
376
+ for script in self.alwayson_scripts:
377
+ with gr.Group() as group:
378
+ create_script_ui(script, inputs, inputs_alwayson)
379
+
380
+ script.group = group
381
+
382
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
383
+ inputs[0] = dropdown
384
+
385
+ for script in self.selectable_scripts:
386
+ with gr.Group(visible=False) as group:
387
+ create_script_ui(script, inputs, inputs_alwayson)
388
+
389
+ script.group = group
390
+
391
+ def select_script(script_index):
392
+ selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
393
+
394
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
395
+
396
+ def init_field(title):
397
+ """called when an initial value is set from ui-config.json to show script's UI components"""
398
+
399
+ if title == 'None':
400
+ return
401
+
402
+ script_index = self.titles.index(title)
403
+ self.selectable_scripts[script_index].group.visible = True
404
+
405
+ dropdown.init_field = init_field
406
+
407
+ dropdown.change(
408
+ fn=select_script,
409
+ inputs=[dropdown],
410
+ outputs=[script.group for script in self.selectable_scripts]
411
+ )
412
+
413
+ self.script_load_ctr = 0
414
+ def onload_script_visibility(params):
415
+ title = params.get('Script', None)
416
+ if title:
417
+ title_index = self.titles.index(title)
418
+ visibility = title_index == self.script_load_ctr
419
+ self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
420
+ return gr.update(visible=visibility)
421
+ else:
422
+ return gr.update(visible=False)
423
+
424
+ self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
425
+ self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
426
+
427
+ return inputs
428
+
429
+ def run(self, p, *args):
430
+ script_index = args[0]
431
+
432
+ if script_index == 0:
433
+ return None
434
+
435
+ script = self.selectable_scripts[script_index-1]
436
+
437
+ if script is None:
438
+ return None
439
+
440
+ script_args = args[script.args_from:script.args_to]
441
+ processed = script.run(p, *script_args)
442
+
443
+ shared.total_tqdm.clear()
444
+
445
+ return processed
446
+
447
+ def process(self, p):
448
+ for script in self.alwayson_scripts:
449
+ try:
450
+ script_args = p.script_args[script.args_from:script.args_to]
451
+ script.process(p, *script_args)
452
+ except Exception:
453
+ print(f"Error running process: {script.filename}", file=sys.stderr)
454
+ print(traceback.format_exc(), file=sys.stderr)
455
+
456
+ def before_process_batch(self, p, **kwargs):
457
+ for script in self.alwayson_scripts:
458
+ try:
459
+ script_args = p.script_args[script.args_from:script.args_to]
460
+ script.before_process_batch(p, *script_args, **kwargs)
461
+ except Exception:
462
+ print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
463
+ print(traceback.format_exc(), file=sys.stderr)
464
+
465
+ def process_batch(self, p, **kwargs):
466
+ for script in self.alwayson_scripts:
467
+ try:
468
+ script_args = p.script_args[script.args_from:script.args_to]
469
+ script.process_batch(p, *script_args, **kwargs)
470
+ except Exception:
471
+ print(f"Error running process_batch: {script.filename}", file=sys.stderr)
472
+ print(traceback.format_exc(), file=sys.stderr)
473
+
474
+ def postprocess(self, p, processed):
475
+ for script in self.alwayson_scripts:
476
+ try:
477
+ script_args = p.script_args[script.args_from:script.args_to]
478
+ script.postprocess(p, processed, *script_args)
479
+ except Exception:
480
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
481
+ print(traceback.format_exc(), file=sys.stderr)
482
+
483
+ def postprocess_batch(self, p, images, **kwargs):
484
+ for script in self.alwayson_scripts:
485
+ try:
486
+ script_args = p.script_args[script.args_from:script.args_to]
487
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
488
+ except Exception:
489
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
490
+ print(traceback.format_exc(), file=sys.stderr)
491
+
492
+ def postprocess_image(self, p, pp: PostprocessImageArgs):
493
+ for script in self.alwayson_scripts:
494
+ try:
495
+ script_args = p.script_args[script.args_from:script.args_to]
496
+ script.postprocess_image(p, pp, *script_args)
497
+ except Exception:
498
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
499
+ print(traceback.format_exc(), file=sys.stderr)
500
+
501
+ def before_component(self, component, **kwargs):
502
+ for script in self.scripts:
503
+ try:
504
+ script.before_component(component, **kwargs)
505
+ except Exception:
506
+ print(f"Error running before_component: {script.filename}", file=sys.stderr)
507
+ print(traceback.format_exc(), file=sys.stderr)
508
+
509
+ def after_component(self, component, **kwargs):
510
+ for script in self.scripts:
511
+ try:
512
+ script.after_component(component, **kwargs)
513
+ except Exception:
514
+ print(f"Error running after_component: {script.filename}", file=sys.stderr)
515
+ print(traceback.format_exc(), file=sys.stderr)
516
+
517
+ def reload_sources(self, cache):
518
+ for si, script in list(enumerate(self.scripts)):
519
+ args_from = script.args_from
520
+ args_to = script.args_to
521
+ filename = script.filename
522
+
523
+ module = cache.get(filename, None)
524
+ if module is None:
525
+ module = script_loading.load_module(script.filename)
526
+ cache[filename] = module
527
+
528
+ for script_class in module.__dict__.values():
529
+ if type(script_class) == type and issubclass(script_class, Script):
530
+ self.scripts[si] = script_class()
531
+ self.scripts[si].filename = filename
532
+ self.scripts[si].args_from = args_from
533
+ self.scripts[si].args_to = args_to
534
+
535
+
536
+ scripts_txt2img: ScriptRunner = None
537
+ scripts_img2img: ScriptRunner = None
538
+ scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
539
+ scripts_current: ScriptRunner = None
540
+
541
+
542
+ def reload_script_body_only():
543
+ cache = {}
544
+ scripts_txt2img.reload_sources(cache)
545
+ scripts_img2img.reload_sources(cache)
546
+
547
+
548
+ reload_scripts = load_scripts # compatibility alias
549
+
550
+
551
+ def add_classes_to_gradio_component(comp):
552
+ """
553
+ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
554
+ """
555
+
556
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
557
+
558
+ if getattr(comp, 'multiselect', False):
559
+ comp.elem_classes.append('multiselect')
560
+
561
+
562
+
563
+ def IOComponent_init(self, *args, **kwargs):
564
+ if scripts_current is not None:
565
+ scripts_current.before_component(self, **kwargs)
566
+
567
+ script_callbacks.before_component_callback(self, **kwargs)
568
+
569
+ res = original_IOComponent_init(self, *args, **kwargs)
570
+
571
+ add_classes_to_gradio_component(self)
572
+
573
+ script_callbacks.after_component_callback(self, **kwargs)
574
+
575
+ if scripts_current is not None:
576
+ scripts_current.after_component(self, **kwargs)
577
+
578
+ return res
579
+
580
+
581
+ original_IOComponent_init = gr.components.IOComponent.__init__
582
+ gr.components.IOComponent.__init__ = IOComponent_init
583
+
584
+
585
+ def BlockContext_init(self, *args, **kwargs):
586
+ res = original_BlockContext_init(self, *args, **kwargs)
587
+
588
+ add_classes_to_gradio_component(self)
589
+
590
+ return res
591
+
592
+
593
+ original_BlockContext_init = gr.blocks.BlockContext.__init__
594
+ gr.blocks.BlockContext.__init__ = BlockContext_init