| import os |
| import json |
| import platform |
| import gradio as gr |
| from PIL import Image |
| from gradio import context |
| from modules import shared, script_callbacks, scripts |
| from modules.shared import list_checkpoint_tiles, refresh_checkpoints |
| from modules.ui import create_refresh_button |
| from modules.generation_parameters_copypaste import ( |
| registered_param_bindings, |
| create_buttons, |
| register_paste_params_button, |
| connect_paste_params_buttons, |
| ParamBinding, |
| ) |
|
|
| from agent_scheduler.task_runner import ( |
| TaskRunner, |
| get_instance, |
| task_history_retenion_map, |
| ) |
| from agent_scheduler.helpers import ( |
| log, |
| compare_components_with_ids, |
| get_components_by_ids, |
| ) |
| from agent_scheduler.db import init as init_db, task_manager, TaskStatus |
| from agent_scheduler.api import regsiter_apis |
|
|
| task_runner: TaskRunner = None |
|
|
| checkpoint_current = "Current Checkpoint" |
| checkpoint_runtime = "Runtime Checkpoint" |
|
|
| ui_placement_as_tab = "As a tab" |
| ui_placement_append_to_main = "Append to main UI" |
|
|
| placement_under_generate = "Under Generate button" |
| placement_between_prompt_and_generate = "Between Prompt and Generate button" |
|
|
| task_filter_choices = ["All", "Bookmarked", "Done", "Failed", "Interrupted"] |
|
|
| is_macos = platform.system() == "Darwin" |
| enqueue_key_modifiers = [ |
| "Command" if is_macos else "Ctrl", |
| "Control" if is_macos else "Alt", |
| "Shift", |
| ] |
| enqueue_default_hotkey = enqueue_key_modifiers[0] + "+KeyE" |
| enqueue_key_codes = {} |
| enqueue_key_codes.update( |
| {chr(i): "Key" + chr(i) for i in range(ord("A"), ord("Z") + 1)} |
| ) |
| enqueue_key_codes.update( |
| {chr(i): "Digit" + chr(i) for i in range(ord("0"), ord("9") + 1)} |
| ) |
| enqueue_key_codes.update({"`": "Backquote", "Enter": "Enter"}) |
|
|
| init_db() |
|
|
|
|
| class Script(scripts.Script): |
| def __init__(self): |
| super().__init__() |
| script_callbacks.on_app_started(lambda block, _: self.on_app_started(block)) |
| self.checkpoint_override = checkpoint_current |
| self.generate_button = None |
| self.enqueue_row = None |
| self.checkpoint_dropdown = None |
| self.submit_button = None |
|
|
| def title(self): |
| return "Agent Scheduler" |
|
|
| def show(self, is_img2img): |
| return scripts.AlwaysVisible |
|
|
| def on_checkpoint_changed(self, checkpoint): |
| self.checkpoint_override = checkpoint |
|
|
| def after_component(self, component, **_kwargs): |
| generate_id = "txt2img_generate" if self.is_txt2img else "img2img_generate" |
| neg_id = "txt2img_neg_prompt" if self.is_txt2img else "img2img_neg_prompt" |
|
|
| if component.elem_id == generate_id: |
| self.generate_button = component |
| if ( |
| getattr(shared.opts, "queue_button_placement", placement_under_generate) |
| == placement_under_generate |
| ): |
| self.add_enqueue_button() |
| component.parent.children.pop() |
| component.parent.parent.add(self.enqueue_row) |
| return |
|
|
| if ( |
| component.elem_id == neg_id |
| and getattr(shared.opts, "queue_button_placement", placement_under_generate) |
| == placement_between_prompt_and_generate |
| ): |
| toprow = component.parent.parent.parent.parent.parent |
| self.add_enqueue_button() |
| component.parent.children.pop() |
| toprow.add(self.enqueue_row) |
|
|
| def on_app_started(self, block): |
| if self.generate_button is not None: |
| self.bind_enqueue_button(block) |
|
|
| def add_enqueue_button(self): |
| id_part = "img2img" if self.is_img2img else "txt2img" |
| with gr.Row(elem_id=f"{id_part}_enqueue_wrapper") as row: |
| self.enqueue_row = row |
| if not getattr(shared.opts, "queue_button_hide_checkpoint", True): |
| self.checkpoint_dropdown = gr.Dropdown( |
| choices=get_checkpoint_choices(), |
| value=checkpoint_current, |
| show_label=False, |
| interactive=True, |
| ) |
| create_refresh_button( |
| self.checkpoint_dropdown, |
| refresh_checkpoints, |
| lambda: {"choices": get_checkpoint_choices()}, |
| f"refresh_{id_part}_checkpoint", |
| ) |
| self.submit_button = gr.Button( |
| "Enqueue", elem_id=f"{id_part}_enqueue", variant="primary" |
| ) |
|
|
| def bind_enqueue_button(self, root: gr.Blocks): |
| generate = self.generate_button |
| is_img2img = self.is_img2img |
| dependencies: list[dict] = [ |
| x |
| for x in root.dependencies |
| if x["trigger"] == "click" and generate._id in x["targets"] |
| ] |
|
|
| dependency: dict = None |
| cnet_dependency: dict = None |
| UiControlNetUnit = None |
| for d in dependencies: |
| if len(d["outputs"]) == 1: |
| outputs = get_components_by_ids(root, d["outputs"]) |
| output = outputs[0] |
| if ( |
| isinstance(output, gr.State) |
| and type(output.value).__name__ == "UiControlNetUnit" |
| ): |
| cnet_dependency = d |
| UiControlNetUnit = type(output.value) |
|
|
| elif len(d["outputs"]) == 4: |
| dependency = d |
|
|
| with root: |
| if self.checkpoint_dropdown is not None: |
| self.checkpoint_dropdown.change( |
| fn=self.on_checkpoint_changed, inputs=[self.checkpoint_dropdown] |
| ) |
|
|
| fn_block = next( |
| fn |
| for fn in root.fns |
| if compare_components_with_ids(fn.inputs, dependency["inputs"]) |
| ) |
| fn = self.wrap_register_ui_task() |
| args = dict( |
| fn=fn, |
| _js="submit_enqueue_img2img" if is_img2img else "submit_enqueue", |
| inputs=fn_block.inputs, |
| outputs=None, |
| show_progress=False, |
| ) |
|
|
| self.submit_button.click(**args) |
|
|
| if cnet_dependency is not None: |
| cnet_fn_block = next( |
| fn |
| for fn in root.fns |
| if compare_components_with_ids(fn.inputs, cnet_dependency["inputs"]) |
| ) |
| self.submit_button.click( |
| fn=UiControlNetUnit, |
| inputs=cnet_fn_block.inputs, |
| outputs=cnet_fn_block.outputs, |
| queue=False, |
| ) |
|
|
| def wrap_register_ui_task(self): |
| def f(*args, **kwargs): |
| if len(args) == 0 and len(kwargs) == 0: |
| raise Exception("Invalid call") |
|
|
| if len(args) > 0 and type(args[0]) == str: |
| task_id = args[0] |
| else: |
| |
| return (None, "", "<p>Invalid params</p>", "") |
|
|
| checkpoint = None |
| if self.checkpoint_override == checkpoint_current: |
| checkpoint = shared.sd_model.sd_checkpoint_info.title |
| elif self.checkpoint_override != checkpoint_runtime: |
| checkpoint = self.checkpoint_override |
|
|
| task_runner.register_ui_task( |
| task_id, self.is_img2img, *args, checkpoint=checkpoint |
| ) |
| task_runner.execute_pending_tasks_threading() |
|
|
| return f |
|
|
|
|
| def get_checkpoint_choices(): |
| choices = [checkpoint_current, checkpoint_runtime] |
| choices.extend(list_checkpoint_tiles()) |
| return choices |
|
|
|
|
| def get_task_results(task_id: str, image_idx: int = None): |
| task = task_manager.get_task(task_id) |
|
|
| galerry = None |
| infotexts = None |
| if task is None: |
| pass |
| elif task.status != TaskStatus.DONE: |
| infotexts = f"Status: {task.status}" |
| if task.status == TaskStatus.FAILED and task.result: |
| infotexts += f"\nError: {task.result}" |
| elif task.status == TaskStatus.DONE: |
| try: |
| result: dict = json.loads(task.result) |
| images = result.get("images", []) |
| infos = result.get("infotexts", []) |
| galerry = ( |
| [Image.open(i) for i in images if os.path.exists(i)] |
| if image_idx is None |
| else gr.update() |
| ) |
| idx = image_idx if image_idx is not None else 0 |
| if len(infos) == len(images): |
| infotexts = infos[idx] |
| else: |
| infotexts = "\n".join(infos).split("Prompt: ")[1:][idx] |
|
|
| except Exception as e: |
| log.error(f"[AgentScheduler] Failed to load task result") |
| log.error(e) |
| infotexts = f"Failed to load task result: {str(e)}" |
|
|
| res = ( |
| gr.Textbox.update(infotexts, visible=infotexts is not None), |
| gr.Row.update(visible=galerry is not None), |
| ) |
| return res if image_idx is not None else (galerry,) + res |
|
|
|
|
| def on_ui_tab(**_kwargs): |
| with gr.Blocks(analytics_enabled=False) as scheduler_tab: |
| with gr.Tabs(elem_id="agent_scheduler_tabs"): |
| with gr.Tab( |
| "Task Queue", id=0, elem_id="agent_scheduler_pending_tasks_tab" |
| ): |
| with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"): |
| with gr.Column(scale=1): |
| with gr.Group(elem_id="agent_scheduler_pending_tasks_actions"): |
| paused = getattr(shared.opts, "queue_paused", False) |
|
|
| gr.Button( |
| "Pause", |
| elem_id="agent_scheduler_action_pause", |
| variant="stop", |
| visible=not paused, |
| ) |
| gr.Button( |
| "Resume", |
| elem_id="agent_scheduler_action_resume", |
| variant="primary", |
| visible=paused, |
| ) |
| gr.Button( |
| "Refresh", |
| elem_id="agent_scheduler_action_reload", |
| variant="secondary", |
| ) |
| gr.HTML('<div id="agent_scheduler_action_search"></div>') |
| gr.HTML( |
| '<div id="agent_scheduler_pending_tasks_grid" class="ag-theme-alpine"></div>' |
| ) |
| with gr.Column(scale=1): |
| gr.Gallery( |
| elem_id="agent_scheduler_current_task_images", |
| label="Output", |
| show_label=False, |
| ).style(columns=2, object_fit="contain") |
| with gr.Tab("Task History", id=1, elem_id="agent_scheduler_history_tab"): |
| with gr.Row(elem_id="agent_scheduler_history_wrapper"): |
| with gr.Column(scale=1): |
| with gr.Group(elem_id="agent_scheduler_history_actions"): |
| gr.Button( |
| "Refresh", |
| elem_id="agent_scheduler_action_refresh_history", |
| elem_classes="agent_scheduler_action_refresh", |
| variant="secondary", |
| ) |
| status = gr.Dropdown( |
| elem_id="agent_scheduler_status_filter", |
| choices=task_filter_choices, |
| value="All", |
| show_label=False, |
| ) |
| gr.HTML( |
| '<div id="agent_scheduler_action_search_history"></div>' |
| ) |
| gr.HTML( |
| '<div id="agent_scheduler_history_tasks_grid" class="ag-theme-alpine"></div>' |
| ) |
| with gr.Column(scale=1, elem_id="agent_scheduler_history_results"): |
| galerry = gr.Gallery( |
| elem_id="agent_scheduler_history_gallery", |
| label="Output", |
| show_label=False, |
| ).style(columns=2, object_fit="contain", preview=True) |
| gen_info = gr.TextArea( |
| label="Generation Info", |
| elem_id=f"agent_scheduler_history_gen_info", |
| interactive=False, |
| visible=True, |
| lines=3, |
| ) |
| with gr.Row( |
| elem_id="agent_scheduler_history_result_actions", |
| visible=False, |
| ) as result_actions: |
| try: |
| send_to_buttons = create_buttons( |
| ["txt2img", "img2img", "inpaint", "extras"] |
| ) |
| except: |
| pass |
| selected_task = gr.Textbox( |
| elem_id="agent_scheduler_history_selected_task", |
| visible=False, |
| show_label=False, |
| ) |
| selected_task_id = gr.Textbox( |
| elem_id="agent_scheduler_history_selected_image", |
| visible=False, |
| show_label=False, |
| ) |
|
|
| |
| status.change( |
| fn=lambda x: None, |
| _js="agent_scheduler_status_filter_changed", |
| inputs=[status], |
| ) |
| selected_task.change( |
| fn=get_task_results, |
| inputs=[selected_task], |
| outputs=[galerry, gen_info, result_actions], |
| ) |
| selected_task_id.change( |
| fn=lambda x, y: get_task_results(x, image_idx=int(y)), |
| inputs=[selected_task, selected_task_id], |
| outputs=[gen_info, result_actions], |
| ) |
| try: |
| for paste_tabname, paste_button in send_to_buttons.items(): |
| register_paste_params_button( |
| ParamBinding( |
| paste_button=paste_button, |
| tabname=paste_tabname, |
| source_text_component=gen_info, |
| source_image_component=galerry, |
| ) |
| ) |
| except: |
| pass |
|
|
| return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")] |
|
|
|
|
| def on_ui_settings(): |
| section = ("agent_scheduler", "Agent Scheduler") |
| shared.opts.add_option( |
| "queue_paused", |
| shared.OptionInfo( |
| False, |
| "Disable queue auto-processing", |
| gr.Checkbox, |
| {"interactive": True}, |
| section=section, |
| ), |
| ) |
| shared.opts.add_option( |
| "queue_button_placement", |
| shared.OptionInfo( |
| placement_under_generate, |
| "Queue button placement", |
| gr.Radio, |
| lambda: { |
| "choices": [ |
| placement_under_generate, |
| placement_between_prompt_and_generate, |
| ] |
| }, |
| section=section, |
| ), |
| ) |
| shared.opts.add_option( |
| "queue_button_hide_checkpoint", |
| shared.OptionInfo( |
| True, |
| "Hide the checkpoint dropdown", |
| gr.Checkbox, |
| {}, |
| section=section, |
| ), |
| ) |
| shared.opts.add_option( |
| "queue_history_retention_days", |
| shared.OptionInfo( |
| "30 days", |
| "Auto delete queue history (bookmarked tasks excluded)", |
| gr.Radio, |
| lambda: { |
| "choices": list(task_history_retenion_map.keys()), |
| }, |
| section=section, |
| ), |
| ) |
|
|
| def enqueue_keyboard_shortcut(disabled: bool, modifiers: list[str], key_code: str): |
| if disabled: |
| modifiers.insert(0, "Disabled") |
|
|
| shortcut = "+".join(sorted(modifiers) + [enqueue_key_codes[key_code]]) |
|
|
| return ( |
| shortcut, |
| gr.CheckboxGroup.update(interactive=not disabled), |
| gr.Dropdown.update(interactive=not disabled), |
| ) |
|
|
| def enqueue_keyboard_shortcut_ui(**_kwargs): |
| value = _kwargs.get("value", enqueue_default_hotkey) |
| parts = value.split("+") |
| key = parts.pop() |
| key_code_value = [k for k, v in enqueue_key_codes.items() if v == key] |
| modifiers = [m for m in parts if m in enqueue_key_modifiers] |
| disabled = "Disabled" in value |
|
|
| with gr.Group(elem_id="enqueue_keyboard_shortcut_wrapper"): |
| modifiers = gr.CheckboxGroup( |
| enqueue_key_modifiers, |
| value=modifiers, |
| label="Enqueue keyboard shortcut", |
| elem_id="enqueue_keyboard_shortcut_modifiers", |
| interactive=not disabled, |
| ) |
| key_code = gr.Dropdown( |
| choices=list(enqueue_key_codes.keys()), |
| value="E" if len(key_code_value) == 0 else key_code_value[0], |
| elem_id="enqueue_keyboard_shortcut_key", |
| label="Key", |
| interactive=not disabled, |
| ) |
| shortcut = gr.Textbox(**_kwargs) |
| disable = gr.Checkbox( |
| value=disabled, |
| label="Disable keyboard shortcut", |
| elem_id="enqueue_keyboard_shortcut_disable", |
| ) |
|
|
| modifiers.change( |
| fn=enqueue_keyboard_shortcut, |
| inputs=[disable, modifiers, key_code], |
| outputs=[shortcut, modifiers, key_code], |
| ) |
| key_code.change( |
| fn=enqueue_keyboard_shortcut, |
| inputs=[disable, modifiers, key_code], |
| outputs=[shortcut, modifiers, key_code], |
| ) |
| disable.change( |
| fn=enqueue_keyboard_shortcut, |
| inputs=[disable, modifiers, key_code], |
| outputs=[shortcut, modifiers, key_code], |
| ) |
|
|
| return shortcut |
|
|
| shared.opts.add_option( |
| "queue_keyboard_shortcut", |
| shared.OptionInfo( |
| enqueue_default_hotkey, |
| "Enqueue keyboard shortcut", |
| enqueue_keyboard_shortcut_ui, |
| { |
| "interactive": False, |
| }, |
| section=section, |
| ), |
| ) |
|
|
| shared.opts.add_option( |
| "queue_ui_placement", |
| shared.OptionInfo( |
| ui_placement_as_tab, |
| "Task queue UI placement", |
| gr.Radio, |
| lambda: { |
| "choices": [ |
| ui_placement_as_tab, |
| ui_placement_append_to_main, |
| ] |
| }, |
| section=section, |
| ), |
| ) |
|
|
|
|
| def on_app_started(block: gr.Blocks, app): |
| global task_runner |
| task_runner = get_instance(block) |
| task_runner.execute_pending_tasks_threading() |
| regsiter_apis(app, task_runner) |
|
|
| if ( |
| getattr(shared.opts, "queue_ui_placement", "") == ui_placement_append_to_main |
| and block |
| ): |
| with block: |
| with block.children[1]: |
| bindings = registered_param_bindings.copy() |
| registered_param_bindings.clear() |
| on_ui_tab() |
| connect_paste_params_buttons() |
| registered_param_bindings.extend(bindings) |
|
|
|
|
| if getattr(shared.opts, "queue_ui_placement", "") != ui_placement_append_to_main: |
| script_callbacks.on_ui_tabs(on_ui_tab) |
|
|
| script_callbacks.on_ui_settings(on_ui_settings) |
| script_callbacks.on_app_started(on_app_started) |
|
|