| |
|
|
| from torch import Tensor |
| import torch |
| from PIL import Image |
| import numpy as np |
|
|
| import os |
| import sys |
| import json |
| import folder_paths |
|
|
| |
# Directory holding this module. It is put on sys.path, presumably so sibling modules
# such as tsc_sd can be imported by plain name (TODO confirm).
my_dir = os.path.dirname(os.path.abspath(__file__))

# Two levels above this module. It is put on sys.path so the `comfy` package can be
# imported; presumably this is the ComfyUI root when installed under custom_nodes.
comfy_dir = os.path.abspath(os.path.join(my_dir, '..', '..'))

# Same effect as appending my_dir and then comfy_dir.
sys.path.extend([my_dir, comfy_dir])
|
|
| |
| import comfy.sd |
|
|
| |
| from tsc_sd import * |
|
|
| |
# Model caches, one list per model type. The loaders in this module append tuples of the form:
#   "ckpt": (ckpt_name, model, clip, vae, [ids])
#   "vae":  (vae_name, vae, [ids])
#   "lora": (lora_params, ckpt_name, lora_model, lora_clip, [ids])
# where [ids] lists the node ids currently referencing the entry.
loaded_objects = {kind: [] for kind in ("ckpt", "vae", "lora")}

# Values held per node; every list stores (value, id) pairs.
last_helds: dict[str, list] = {kind: [] for kind in ("results", "latent", "images", "vae_decode")}
|
|
| |
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """
    Convert an image tensor to a PIL image.

    Steps:
    - Move the tensor to the CPU.
    - Drop all size-1 dimensions.
    - Scale values by 255, clip them to [0, 255] and cast to uint8.

    Input values are presumably floats in [0, 1] (TODO confirm against callers).
    """
    pixels = image.cpu().numpy().squeeze()
    pixels = np.clip(pixels * 255.0, 0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
|
|
| |
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """
    Convert a PIL image to a float32 tensor with a leading dimension of size 1.

    Pixel values are divided by 255, so 8-bit images end up in [0, 1].
    """
    pixels = np.array(image).astype(np.float32)
    batched = torch.from_numpy(pixels / 255.0)
    return batched[None]
|
|
def extract_node_info(prompt, id, indirect_key=None):
    """
    Look up a node's class_type in ``prompt``, optionally following one input link.

    Parameters:
    - prompt: mapping of node id (str) -> node dict with 'class_type' and 'inputs'.
    - id: node id; converted to str before lookup.
    - indirect_key (optional): name of an input on node ``id``. When given, the input's
      value is treated as a link whose first element is the source node id (presumably a
      ComfyUI link of the form [source_node_id, output_index]). The source node is reported.

    Returns:
    - (class_type, None) for a direct lookup. class_type is None if the node is unknown.
    - (class_type, source_node_id) when indirect_key resolves to a node in prompt.
    - (None, None) when indirect_key is given but the input is missing, is not a
      non-empty list/tuple, or points to a node that is not in prompt.
    """
    id = str(id)

    if indirect_key:
        link = prompt.get(id, {}).get('inputs', {}).get(indirect_key)
        # Only list/tuple values are links. Literal widget values (numbers, strings) must not
        # be indexed: "2x"[0] used to be mistaken for a link to node "2", and ints crashed.
        if isinstance(link, (list, tuple)) and link:
            # Prompt keys are strings; normalise so int link ids resolve too.
            source_id = str(link[0])
            if source_id in prompt:
                return prompt[source_id].get('class_type', None), source_id
        return None, None

    return prompt.get(id, {}).get('class_type', None), None
|
|
def extract_node_value(prompt, id, key):
    """Return prompt[str(id)]['inputs'][key], or None when the node, its inputs, or the key is missing."""
    node = prompt.get(str(id), {})
    node_inputs = node.get('inputs', {})
    return node_inputs.get(key, None)
|
|
def print_loaded_objects_entries(id=None, prompt=None, show_id=False):
    """
    Print the contents of loaded_objects, optionally filtered to a single node.

    Parameters:
    - id (optional): node id to filter by, compared as a string. When None, every entry
      is printed together with the ids that reference it.
    - prompt (optional): prompt mapping. When given together with id, the header uses
      the node's class_type, looked up via extract_node_info.
    - show_id: when the prompt-based header is printed, also include the node id.
    """
    def stem(path):
        # File name without directory or extension, e.g. "models/a.safetensors" -> "a".
        return os.path.splitext(os.path.basename(path))[0]

    print("-" * 40)
    if id is not None:
        id = str(id)
    # The id is printed verbatim: int(id) raised ValueError for non-numeric ids such as "7:1".
    if prompt is not None and id is not None:
        node_name, _ = extract_node_info(prompt, id)
        if show_id:
            print(f"\033[36m{node_name} Models Cache: (node_id:{id})\033[0m")
        else:
            print(f"\033[36m{node_name} Models Cache:\033[0m")
    elif id is None:
        print("\033[36mGlobal Models Cache:\033[0m")
    else:
        print(f"\033[36mModels Cache: \nnode_id:{id}\033[0m")

    entries_found = False
    for key in ["ckpt", "vae", "lora"]:
        if id is None:
            entries_with_id = loaded_objects[key]
        else:
            # Stored ids are not guaranteed to be str, so compare their string forms.
            entries_with_id = [entry for entry in loaded_objects[key] if id in map(str, entry[-1])]
        if not entries_with_id:
            continue
        entries_found = True
        print(f"{key.capitalize()}:")
        for i, entry in enumerate(entries_with_id, 1):
            if key == "lora":
                lora_models_info = ', '.join(
                    f"{stem(name)}({round(strength_model, 2)},{round(strength_clip, 2)})"
                    for name, strength_model, strength_clip in entry[0]
                )
                label = f"base_ckpt: {stem(entry[1])}, lora(mod,clip): {lora_models_info}"
            else:
                label = stem(entry[0])
            if id is None:
                associated_ids = ', '.join(map(str, entry[-1]))
                print(f" [{i}] {label} (ids: {associated_ids})")
            else:
                print(f" [{i}] {label}")
    if not entries_found:
        print("-")
|
|
|
|
| |
def globals_cleanup(prompt):
    """
    Drop cached state that belongs to nodes no longer present in ``prompt``.

    Ids are compared as strings against the keys of ``prompt``.
    - last_helds: (value, id) pairs whose id is not in prompt are removed.
    - loaded_objects: each entry's id list is filtered to ids still in prompt.
      - Entries left with no ids are removed.
      - Entries that lost only some ids are replaced by a copy carrying the filtered id list.
      - Untouched entries are kept as-is.
      Each loaded_objects list is updated in place, so existing references to it stay valid.
    """
    for key in list(last_helds):
        last_helds[key] = [(value, id) for value, id in last_helds[key] if str(id) in prompt]

    for key in list(loaded_objects):
        kept = []
        for entry in loaded_objects[key]:
            live_ids = [id for id in entry[-1] if str(id) in prompt]
            if len(live_ids) == len(entry[-1]):
                kept.append(entry)
            elif live_ids:
                kept.append(entry[:-1] + (live_ids,))
            # Otherwise no referencing node remains, so the entry is dropped.
        # Rebuild rather than patch by index. The previous version wrote to
        # loaded_objects[key][i] using indices from a copy after removing items from the
        # live list. That overwrote the wrong entry or raised IndexError.
        loaded_objects[key][:] = kept
| |
|
|
def load_checkpoint(ckpt_name, id, output_vae=True, cache=None, cache_overwrite=False):
    """
    Return (model, clip, vae) for ``ckpt_name``, reusing loaded_objects["ckpt"] when possible.

    Cache entries are (ckpt_name, model, clip, vae, [ids]); ``ids`` lists the node ids
    that currently reference the entry.

    Cache hit:
    - If ``cache`` is truthy and ``id`` already references ``cache`` or more checkpoint
      entries, clear_cache is called and ``id`` is not added.
    - Otherwise ``id`` is added to the entry's ids if missing.

    Cache miss: the checkpoint is loaded with load_checkpoint_guess_config_tsc.
    - If ``cache`` is truthy and ``id`` references fewer than ``cache`` entries, a new
      entry is appended.
    - If the limit is reached, clear_cache is called.
      - With ``cache_overwrite``, ``id`` is released from the first entry that references
        it (that entry is dropped if no ids remain) and the new checkpoint is appended.
      - Without ``cache_overwrite``, nothing is cached.

    Parameters:
    - ckpt_name: checkpoint file name, resolved through folder_paths "checkpoints".
    - id: node id used to track which nodes reference a cache entry.
    - output_vae: forwarded to load_checkpoint_guess_config_tsc (presumably controls
      whether a VAE is loaded).
    - cache (optional): max number of checkpoint entries ``id`` may reference; a falsy
      value disables caching.
    - cache_overwrite (optional): when the limit is reached, evict ``id``'s first entry
      to make room for the new checkpoint.

    Returns: (model, clip, vae).
    """
    global loaded_objects

    def held_by_id():
        # Number of checkpoint entries currently associated with this id.
        return sum(1 for item in loaded_objects["ckpt"] if id in item[-1])

    cached = next((item for item in loaded_objects["ckpt"] if item[0] == ckpt_name), None)
    if cached is not None:
        _, model, clip, vae, ids = cached
        if cache and held_by_id() >= cache:
            clear_cache(id, cache, "ckpt")
        elif id not in ids:
            ids.append(id)
        return model, clip, vae

    ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
    loaded = load_checkpoint_guess_config_tsc(
        ckpt_path,
        output_vae,
        output_clip=True,
        embedding_directory=folder_paths.get_folder_paths("embeddings"),
    )
    model, clip, vae = loaded[0], loaded[1], loaded[2]

    if not cache:
        return model, clip, vae

    new_entry = (ckpt_name, model, clip, vae, [id])
    if held_by_id() < cache:
        loaded_objects["ckpt"].append(new_entry)
        return model, clip, vae

    clear_cache(id, cache, "ckpt")
    if cache_overwrite:
        # Release the id from the first (earliest appended) entry that holds it.
        for item in loaded_objects["ckpt"]:
            if id not in item[-1]:
                continue
            item[-1].remove(id)
            if not item[-1]:
                loaded_objects["ckpt"].remove(item)
            break
        loaded_objects["ckpt"].append(new_entry)
    return model, clip, vae
|
|
def get_bvae_by_ckpt_name(ckpt_name):
    """Return the VAE stored with the first cached checkpoint named ``ckpt_name``, or None."""
    return next((item[3] for item in loaded_objects["ckpt"] if item[0] == ckpt_name), None)
|
|
def load_vae(vae_name, id, cache=None, cache_overwrite=False):
    """
    Return the VAE named ``vae_name``, reusing loaded_objects["vae"] when possible.

    Cache entries are (vae_name, vae, [ids]); ``ids`` lists the node ids that currently
    reference the entry.

    Cache hit:
    - If ``id`` is not yet associated with the entry and already references ``cache`` or
      more vae entries, the VAE is returned without registering ``id``.
    - Otherwise ``id`` is added if missing and, when ``cache`` is truthy, clear_cache
      trims this id's entries down to ``cache``.

    Cache miss: the VAE is loaded from disk.
    - If ``cache`` is truthy and ``id`` references fewer than ``cache`` entries, a new
      entry is appended.
    - If the limit is reached, clear_cache is called.
      - With ``cache_overwrite``, ``id`` is released from the first entry that references
        it (that entry is dropped if no ids remain) and the new VAE is appended.
      - Without ``cache_overwrite``, nothing is cached.

    Parameters:
    - vae_name: VAE file name, resolved through folder_paths "vae".
    - id: node id used to track which nodes reference a cache entry (required).
    - cache (optional): max number of vae entries ``id`` may reference; a falsy value
      disables caching.
    - cache_overwrite (optional): when the limit is reached, evict ``id``'s first entry
      to make room for the new VAE.

    Returns: the VAE object.
    """
    global loaded_objects

    for entry in loaded_objects["vae"]:
        if entry[0] == vae_name:
            vae, ids = entry[1], entry[2]
            if id not in ids:
                if cache and len([e for e in loaded_objects["vae"] if id in e[-1]]) >= cache:
                    return vae
                ids.append(id)
            if cache:
                clear_cache(id, cache, "vae")
            return vae

    vae_path = folder_paths.get_full_path("vae", vae_name)
    vae = comfy.sd.VAE(ckpt_path=vae_path)

    if cache:
        if len([e for e in loaded_objects["vae"] if id in e[-1]]) < cache:
            loaded_objects["vae"].append((vae_name, vae, [id]))
        else:
            clear_cache(id, cache, "vae")
            if cache_overwrite:
                # Release the id from the first (earliest appended) entry that holds it.
                for e in loaded_objects["vae"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        if not e[-1]:
                            loaded_objects["vae"].remove(e)
                        break
                loaded_objects["vae"].append((vae_name, vae, [id]))

    return vae
|
|
def load_lora(lora_params, ckpt_name, id, cache=None, ckpt_cache=None, cache_overwrite=False):
    """
    Return (lora_model, lora_clip) for checkpoint ``ckpt_name`` with the LoRA stack
    ``lora_params`` applied, reusing loaded_objects["lora"] when possible.

    Cache entries are (lora_params, ckpt_name, lora_model, lora_clip, [ids]). Two stacks
    match when they contain the same set of (lora_name, strength_model, strength_clip)
    tuples, so their order is ignored.

    Cache hit: id bookkeeping is done for the LoRA entry (limit ``cache``) and for the
    base checkpoint entry (limit ``ckpt_cache``). When a limit is reached clear_cache is
    called; otherwise ``id`` is added to that entry's ids if missing.

    Cache miss:
    - The base checkpoint is obtained through load_checkpoint (with ckpt_cache and
      cache_overwrite).
    - Every LoRA in lora_params is applied in order.
    - The result is cached under ``cache``: appended while there is room; otherwise
      clear_cache is called, and with cache_overwrite the id is released from its first
      entry and the new result is appended.
    - An empty lora_params yields the base model and clip with no LoRA applied.

    Parameters:
    - lora_params: list of (lora_name, strength_model, strength_clip) tuples.
    - ckpt_name: name of the checkpoint the LoRAs are applied to.
    - id: node id used to track which nodes reference a cache entry.
    - cache (optional): max LoRA entries ``id`` may reference; a falsy value disables
      LoRA caching.
    - ckpt_cache (optional): checkpoint cache limit, also forwarded to load_checkpoint.
    - cache_overwrite (optional): when a limit is reached, evict ``id``'s first entry to
      make room.

    Returns: (lora_model, lora_clip).
    """
    global loaded_objects

    for entry in loaded_objects["lora"]:
        if set(entry[0]) == set(lora_params) and entry[1] == ckpt_name:
            _, _, lora_model, lora_clip, ids = entry
            cache_full = cache and len([e for e in loaded_objects["lora"] if id in e[-1]]) >= cache
            if cache_full:
                clear_cache(id, cache, "lora")
            elif id not in ids:
                ids.append(id)

            # Keep the base checkpoint's id bookkeeping in step with the reused LoRA entry.
            for ckpt_entry in loaded_objects["ckpt"]:
                if ckpt_entry[0] == ckpt_name:
                    _, _, _, _, ckpt_ids = ckpt_entry
                    ckpt_cache_full = ckpt_cache and len(
                        [e for e in loaded_objects["ckpt"] if id in e[-1]]) >= ckpt_cache
                    if ckpt_cache_full:
                        clear_cache(id, ckpt_cache, "ckpt")
                    elif id not in ckpt_ids:
                        ckpt_ids.append(id)
                    # load_checkpoint only appends names that are not already cached, so
                    # there is one match. Stop before clear_cache's removals can disturb
                    # this iteration.
                    break

            return lora_model, lora_clip

    lora_model, lora_clip, _ = load_checkpoint(ckpt_name, id, cache=ckpt_cache, cache_overwrite=cache_overwrite)

    # Apply each LoRA on top of the result of the previous one, in list order.
    for lora_name, strength_model, strength_clip in lora_params:
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_model, lora_clip = load_lora_for_models_tsc(lora_model, lora_clip, lora_path,
                                                         strength_model, strength_clip)

    if cache:
        if len([e for e in loaded_objects["lora"] if id in e[-1]]) < cache:
            loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))
        else:
            clear_cache(id, cache, "lora")
            if cache_overwrite:
                # Release the id from the first (earliest appended) entry that holds it.
                for e in loaded_objects["lora"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        if not e[-1]:
                            loaded_objects["lora"].remove(e)
                        break
                loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))

    return lora_model, lora_clip
|
|
def clear_cache(id, cache, dict_name):
    """
    Trim the entries of loaded_objects[dict_name] that reference ``id`` down to ``cache``.

    Works on any of the "ckpt", "vae" or "lora" lists. While more than ``cache`` entries
    list ``id``, the id is removed from the earliest such entry in list order. An entry
    whose id list becomes empty is deleted. Nothing happens if ``id`` is already within
    the limit.
    """
    while True:
        holders = [item for item in loaded_objects[dict_name] if id in item[-1]]
        if len(holders) <= cache:
            break
        oldest = holders[0]
        oldest[-1].remove(id)
        if not oldest[-1]:
            loaded_objects[dict_name].remove(oldest)
|
|
def clear_cache_by_exception(node_id, vae_dict=None, ckpt_dict=None, lora_dict=None):
    """
    Release ``node_id`` from every cache entry that is not in the given keep-lists.

    Each of vae_dict, ckpt_dict and lora_dict that is not None is checked against the
    matching loaded_objects list ("vae", "ckpt", "lora"). An entry is kept when:
    - vae/ckpt: its name (entry[0]) is contained in the keep-list;
    - lora: some (lora_params, ckpt_name) pair in lora_dict has the same set of LoRA
      tuples and the same checkpoint name.
    Any other entry has node_id removed from its id list, and is dropped entirely once
    that list is empty. Arguments left as None leave that cache untouched.
    """

    def release(bucket, entry):
        # Detach node_id from the entry and discard the entry once no node references it.
        if node_id in entry[-1]:
            entry[-1].remove(node_id)
            if not entry[-1]:
                bucket.remove(entry)

    for dict_name, keep in (("vae", vae_dict), ("ckpt", ckpt_dict), ("lora", lora_dict)):
        if keep is None:
            continue
        bucket = loaded_objects[dict_name]
        # Iterate over a snapshot because release() may remove entries from bucket.
        for entry in list(bucket):
            if dict_name == "lora":
                # Same LoRA set (order-insensitive) on the same checkpoint counts as a match.
                kept = any(set(params) == set(entry[0]) and ckpt == entry[1] for params, ckpt in keep)
            else:
                kept = entry[0] in keep
            if not kept:
                release(bucket, entry)
|
|
|
|
| |
def get_cache_numbers(node_name, settings_file=None):
    """
    Read the model-cache sizes configured for ``node_name`` in node_settings.json.

    Parameters:
    - node_name: top-level key in the settings file.
    - settings_file (optional): path to the JSON settings file. Defaults to
      node_settings.json in this module's directory.

    Returns: (vae_cache, ckpt_cache, lora_cache) as ints. Each one defaults to 1 when the
    node, its 'model_cache' section, or the individual key is missing.

    Raises: OSError if the file cannot be opened, json.JSONDecodeError if it is malformed,
    ValueError/TypeError if a configured value cannot be converted with int().
    """
    if settings_file is None:
        settings_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'node_settings.json')

    # JSON is UTF-8 by spec; don't depend on the platform's locale default encoding.
    with open(settings_file, 'r', encoding='utf-8') as file:
        node_settings = json.load(file)

    model_cache_settings = node_settings.get(node_name, {}).get('model_cache', {})
    return tuple(int(model_cache_settings.get(kind, 1)) for kind in ('vae', 'ckpt', 'lora'))
|
|
def print_last_helds(id=None):
    """
    Print a summary of the last_helds store, optionally filtered to one node.

    Parameters:
    - id (optional): node id to filter by, compared as a string. When None, every entry
      is printed together with its id.

    Each entry is shown as its value when that value is a bool, otherwise as len(value).
    """
    print("\n" + "-" * 40)
    if id is not None:
        id = str(id)
        # Print the id verbatim: int(id) raised ValueError for non-numeric ids such as "7:1".
        print(f"Node-specific Last Helds (node_id:{id})")
    else:
        print("Global Last Helds:")
    for key in ["results", "latent", "images", "vae_decode"]:
        if id is None:
            entries_with_id = last_helds[key]
        else:
            # Stored ids are not guaranteed to be str, so compare their string forms.
            entries_with_id = [entry for entry in last_helds[key] if str(entry[-1]) == id]
        if not entries_with_id:
            continue
        print(f"{key.capitalize()}:")
        for i, entry in enumerate(entries_with_id, 1):
            output = entry[0] if isinstance(entry[0], bool) else len(entry[0])
            if id is None:
                print(f" [{i}] Output: {output} (id: {entry[-1]})")
            else:
                print(f" [{i}] Output: {output}")
    print("-" * 40)
    print("\n")