| import itertools
|
| from typing import Sequence, Mapping, Dict
|
| from comfy_execution.graph import DynamicPrompt
|
|
|
| import nodes
|
|
|
| from comfy_execution.graph_utils import is_link
|
|
|
# Memoization table: node class name -> whether that class declares a hidden
# "UNIQUE_ID" input (see include_unique_id_in_input below).
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
|
|
|
|
def include_unique_id_in_input(class_type: str) -> bool:
    """Return True if this node class declares a hidden ``UNIQUE_ID`` input.

    The result is memoized in NODE_CLASS_CONTAINS_UNIQUE_ID so that
    ``INPUT_TYPES()`` is evaluated at most once per class type.
    """
    cached = NODE_CLASS_CONTAINS_UNIQUE_ID.get(class_type)
    if cached is not None:
        return cached
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    hidden_inputs = class_def.INPUT_TYPES().get("hidden", {})
    result = "UNIQUE_ID" in hidden_inputs.values()
    NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = result
    return result
|
|
|
class CacheKeySet:
    """Abstract mapping from node ids to cache keys.

    Subclasses decide what a key looks like (id-based, signature-based, ...)
    by implementing :meth:`add_keys`. Two key spaces are tracked: ``keys``
    for node output data and ``subcache_keys`` for nested subcaches.
    """

    def __init__(self, dynprompt, node_ids, is_changed_cache):
        # node_id -> data cache key
        self.keys = {}
        # node_id -> subcache key
        self.subcache_keys = {}

    def add_keys(self, node_ids):
        """Populate key mappings for the given node ids (subclass hook)."""
        raise NotImplementedError()

    def all_node_ids(self):
        """Set of node ids that currently have a data key."""
        return set(self.keys)

    def get_used_keys(self):
        """All data keys currently in use."""
        return self.keys.values()

    def get_used_subcache_keys(self):
        """All subcache keys currently in use."""
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        """Data key for node_id, or None if unknown."""
        return self.keys.get(node_id)

    def get_subcache_key(self, node_id):
        """Subcache key for node_id, or None if unknown."""
        return self.subcache_keys.get(node_id)
|
|
|
class Unhashable:
    """Sentinel wrapper for values that cannot be made hashable.

    Instances hash by identity (default object behavior) and carry a NaN
    payload, which never compares equal to anything — including itself —
    so cache keys containing unhashable inputs never match.
    """

    def __init__(self):
        # NaN: guaranteed unequal to every value, itself included.
        self.value = float("NaN")
|
|
|
def to_hashable(obj):
    """Recursively convert *obj* into a hashable equivalent.

    Scalars (int, float, str, bool, None) pass through unchanged.
    Mappings become frozensets of (key, value) pairs (sorted for
    determinism); sequences become frozensets of (index, item) pairs,
    preserving positional information. Anything else is replaced by a
    fresh Unhashable sentinel that never compares equal to another.
    """
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    if isinstance(obj, Mapping):
        pairs = sorted(obj.items())
        return frozenset((to_hashable(k), to_hashable(v)) for k, v in pairs)
    if isinstance(obj, Sequence):
        return frozenset(enumerate(to_hashable(item) for item in obj))
    # NOTE: not every unhashable type lands here (e.g. sets are neither
    # Mappings nor Sequences), but the common containers are covered.
    return Unhashable()
|
|
|
class CacheKeySetID(CacheKeySet):
    """Cache keys derived purely from node id and class type.

    The cheapest keying strategy: a cache entry is reused whenever the
    same node id keeps the same class type between prompts. Data key and
    subcache key share the same (node_id, class_type) form.
    """

    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.add_keys(node_ids)

    def add_keys(self, node_ids):
        """Register (node_id, class_type) keys for any new, existing nodes."""
        for node_id in node_ids:
            if node_id in self.keys or not self.dynprompt.has_node(node_id):
                continue
            class_type = self.dynprompt.get_node(node_id)["class_type"]
            key = (node_id, class_type)
            self.keys[node_id] = key
            self.subcache_keys[node_id] = key
|
|
|
class CacheKeySetInputSignature(CacheKeySet):
    """Cache keys derived from each node's full input signature.

    A node's key encodes its class type, its is_changed result, its sorted
    literal inputs and — transitively — the signatures of every ancestor it
    links to. A cache hit therefore implies the node would recompute the
    same result. The exact structure and ordering of the signature defines
    key equality, so this construction must stay deterministic.
    """

    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache
        self.add_keys(node_ids)

    def include_node_id_in_input(self) -> bool:
        # Subclasses may override to force keys to be unique per node id.
        return False

    def add_keys(self, node_ids):
        """Compute and store signature keys for node ids not yet seen."""
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            if not self.dynprompt.has_node(node_id):
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
            # Subcaches are still keyed by (id, class_type): they group an
            # expanded node's children rather than deduplicate results.
            self.subcache_keys[node_id] = (node_id, node["class_type"])

    def get_node_signature(self, dynprompt, node_id):
        """Return a hashable signature covering node_id and all ancestors.

        The node's own immediate signature comes first, followed by its
        ancestors' in the deterministic order produced by
        get_ordered_ancestry, so identical graphs always yield equal keys.
        """
        signature = []
        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
        for ancestor_id in ancestors:
            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
        return to_hashable(signature)

    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
        """Signature of a single node, with links encoded positionally.

        Linked inputs are replaced by ("ANCESTOR", index, socket) tuples so
        the signature does not depend on the ancestor's concrete node id,
        only on its position in the deterministic ancestry ordering.
        """
        if not dynprompt.has_node(node_id):
            # Missing node: NaN never compares equal (even to itself), so
            # this signature can never match a cached one.
            return [float("NaN")]
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
        # Nodes that are non-idempotent, or that read their own UNIQUE_ID,
        # must never share cache entries across node ids.
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
            if is_link(inputs[key]):
                (ancestor_id, ancestor_socket) = inputs[key]
                ancestor_index = ancestor_order_mapping[ancestor_id]
                signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
        return signature

    def get_ordered_ancestry(self, dynprompt, node_id):
        """Return (ancestors, order_mapping) for node_id.

        The list order is deterministic: a depth-first walk over linked
        inputs sorted by input name. order_mapping maps ancestor id to its
        index in the list.
        """
        ancestors = []
        order_mapping = {}
        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
        return ancestors, order_mapping

    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
        """Depth-first ancestry walk; mutates ancestors/order_mapping in place."""
        if not dynprompt.has_node(node_id):
            return
        inputs = dynprompt.get_node(node_id)["inputs"]
        input_keys = sorted(inputs.keys())
        for key in input_keys:
            if is_link(inputs[key]):
                ancestor_id = inputs[key][0]
                # Record each ancestor once, at first encounter, then recurse.
                if ancestor_id not in order_mapping:
                    ancestors.append(ancestor_id)
                    order_mapping[ancestor_id] = len(ancestors) - 1
                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
|
|
class BasicCache:
    """Flat cache mapping cache keys to node outputs, plus nested subcaches.

    Keys are produced by a CacheKeySet of the configured key_class. A
    subcache is itself a BasicCache holding the outputs of a dynamically
    expanded node's children. Most operations require set_prompt to have
    been called first (tracked by ``initialized``).
    """

    def __init__(self, key_class):
        self.key_class = key_class
        self.initialized = False
        self.dynprompt: DynamicPrompt
        self.cache_key_set: CacheKeySet
        self.cache = {}
        self.subcaches = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        """Bind this cache to a prompt and (re)build its key set."""
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
        self.is_changed_cache = is_changed_cache
        self.initialized = True

    def all_node_ids(self):
        """All node ids known to this cache or any of its subcaches."""
        assert self.initialized
        node_ids = self.cache_key_set.all_node_ids()
        for subcache in self.subcaches.values():
            node_ids |= subcache.all_node_ids()
        return node_ids

    def _clean_cache(self):
        # Drop data entries whose keys the current key set no longer uses.
        live = set(self.cache_key_set.get_used_keys())
        stale = [key for key in self.cache if key not in live]
        for key in stale:
            del self.cache[key]

    def _clean_subcaches(self):
        # Drop subcaches whose keys the current key set no longer uses.
        live = set(self.cache_key_set.get_used_subcache_keys())
        stale = [key for key in self.subcaches if key not in live]
        for key in stale:
            del self.subcaches[key]

    def clean_unused(self):
        """Evict all entries and subcaches not referenced by the key set."""
        assert self.initialized
        self._clean_cache()
        self._clean_subcaches()

    def _set_immediate(self, node_id, value):
        """Store value under node_id's data key in this cache level."""
        assert self.initialized
        self.cache[self.cache_key_set.get_data_key(node_id)] = value

    def _get_immediate(self, node_id):
        """Value stored at this level for node_id, or None if absent."""
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)
        return self.cache.get(cache_key)

    def _ensure_subcache(self, node_id, children_ids):
        """Create (if needed) and re-prompt the subcache for node_id."""
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            self.subcaches[subcache_key] = subcache
        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache

    def _get_subcache(self, node_id):
        """Subcache owned by node_id, or None if there isn't one."""
        assert self.initialized
        return self.subcaches.get(self.cache_key_set.get_subcache_key(node_id))

    def recursive_debug_dump(self):
        """Dump all entries and subcaches as a list of dicts (debug aid)."""
        dump = [{"key": key, "value": value} for key, value in self.cache.items()]
        for subcache_key, subcache in self.subcaches.items():
            dump.append({"subcache_key": subcache_key, "subcache": subcache.recursive_debug_dump()})
        return dump
|
|
|
class HierarchicalCache(BasicCache):
    """Cache that stores each node in the subcache of its parent chain.

    A node expanded from another node lives in the subcache owned by its
    parent; lookups walk the parent chain from the root down.
    """

    def __init__(self, key_class):
        super().__init__(key_class)

    def _get_cache_for(self, node_id):
        """Return the (sub)cache responsible for node_id, or None if missing."""
        assert self.dynprompt is not None
        parent_id = self.dynprompt.get_parent_node_id(node_id)
        if parent_id is None:
            # Top-level node: stored directly in this cache.
            return self

        # Collect the chain of parents, innermost first.
        lineage = []
        while parent_id is not None:
            lineage.append(parent_id)
            parent_id = self.dynprompt.get_parent_node_id(parent_id)

        # Descend from the outermost ancestor toward the node.
        cache = self
        for ancestor_id in reversed(lineage):
            cache = cache._get_subcache(ancestor_id)
            if cache is None:
                return None
        return cache

    def get(self, node_id):
        """Cached value for node_id, or None if its cache is unreachable."""
        owner = self._get_cache_for(node_id)
        return None if owner is None else owner._get_immediate(node_id)

    def set(self, node_id, value):
        """Store value in the cache level that owns node_id."""
        owner = self._get_cache_for(node_id)
        assert owner is not None
        owner._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        """Create/refresh node_id's subcache and return it."""
        owner = self._get_cache_for(node_id)
        assert owner is not None
        return owner._ensure_subcache(node_id, children_ids)
|
|
|
class LRUCache(BasicCache):
    """Generation-based LRU cache over node outputs.

    Every set_prompt bumps the generation counter and each access stamps
    the entry's key with the current generation. clean_unused raises the
    eviction floor (min_generation) until the cache fits in max_size,
    dropping entries last used below the floor.
    """

    def __init__(self, key_class, max_size=100):
        super().__init__(key_class)
        self.max_size = max_size
        self.min_generation = 0
        self.generation = 0
        # cache key -> generation in which the key was last touched
        self.used_generation = {}
        # cache key -> list of child data keys (from subcache expansion)
        self.children = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        """Start a new generation and mark the prompt's nodes as used."""
        super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)

    def clean_unused(self):
        """Evict least-recently-used entries until the cache fits max_size."""
        while len(self.cache) > self.max_size and self.min_generation < self.generation:
            self.min_generation += 1
            expired = [key for key in self.cache if self.used_generation[key] < self.min_generation]
            for key in expired:
                del self.cache[key]
                del self.used_generation[key]
                self.children.pop(key, None)
        self._clean_subcaches()

    def get(self, node_id):
        """Fetch node_id's value, refreshing its LRU timestamp."""
        self._mark_used(node_id)
        return self._get_immediate(node_id)

    def _mark_used(self, node_id):
        # Stamp the node's data key with the current generation, if keyed.
        key = self.cache_key_set.get_data_key(node_id)
        if key is not None:
            self.used_generation[key] = self.generation

    def set(self, node_id, value):
        """Store node_id's value, refreshing its LRU timestamp."""
        self._mark_used(node_id)
        return self._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        """Register an expansion's children and return self.

        Unlike HierarchicalCache this returns self, not the subcache: the
        LRU cache keys every node flatly, so callers keep using this cache.
        """
        super()._ensure_subcache(node_id, children_ids)
        self.cache_key_set.add_keys(children_ids)
        self._mark_used(node_id)
        parent_key = self.cache_key_set.get_data_key(node_id)
        child_keys = []
        for child_id in children_ids:
            self._mark_used(child_id)
            child_keys.append(self.cache_key_set.get_data_key(child_id))
        self.children[parent_key] = child_keys
        return self
|
|
|
|
|
| class DependencyAwareCache(BasicCache):
|
| """
|
| A cache implementation that tracks dependencies between nodes and manages
|
| their execution and caching accordingly. It extends the BasicCache class.
|
| Nodes are removed from this cache once all of their descendants have been
|
| executed.
|
| """
|
|
|
| def __init__(self, key_class):
|
| """
|
| Initialize the DependencyAwareCache.
|
|
|
| Args:
|
| key_class: The class used for generating cache keys.
|
| """
|
| super().__init__(key_class)
|
| self.descendants = {}
|
| self.ancestors = {}
|
| self.executed_nodes = set()
|
|
|
| def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
| """
|
| Clear the entire cache and rebuild the dependency graph.
|
|
|
| Args:
|
| dynprompt: The dynamic prompt object containing node information.
|
| node_ids: List of node IDs to initialize the cache for.
|
| is_changed_cache: Flag indicating if the cache has changed.
|
| """
|
|
|
| self.cache.clear()
|
| self.subcaches.clear()
|
| self.descendants.clear()
|
| self.ancestors.clear()
|
| self.executed_nodes.clear()
|
|
|
|
|
| super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
|
|
|
|
| self._build_dependency_graph(dynprompt, node_ids)
|
|
|
| def _build_dependency_graph(self, dynprompt, node_ids):
|
| """
|
| Build the dependency graph for all nodes.
|
|
|
| Args:
|
| dynprompt: The dynamic prompt object containing node information.
|
| node_ids: List of node IDs to build the graph for.
|
| """
|
| self.descendants.clear()
|
| self.ancestors.clear()
|
| for node_id in node_ids:
|
| self.descendants[node_id] = set()
|
| self.ancestors[node_id] = set()
|
|
|
| for node_id in node_ids:
|
| inputs = dynprompt.get_node(node_id)["inputs"]
|
| for input_data in inputs.values():
|
| if is_link(input_data):
|
| ancestor_id = input_data[0]
|
| self.descendants[ancestor_id].add(node_id)
|
| self.ancestors[node_id].add(ancestor_id)
|
|
|
| def set(self, node_id, value):
|
| """
|
| Mark a node as executed and store its value in the cache.
|
|
|
| Args:
|
| node_id: The ID of the node to store.
|
| value: The value to store for the node.
|
| """
|
| self._set_immediate(node_id, value)
|
| self.executed_nodes.add(node_id)
|
| self._cleanup_ancestors(node_id)
|
|
|
| def get(self, node_id):
|
| """
|
| Retrieve the cached value for a node.
|
|
|
| Args:
|
| node_id: The ID of the node to retrieve.
|
|
|
| Returns:
|
| The cached value for the node.
|
| """
|
| return self._get_immediate(node_id)
|
|
|
| def ensure_subcache_for(self, node_id, children_ids):
|
| """
|
| Ensure a subcache exists for a node and update dependencies.
|
|
|
| Args:
|
| node_id: The ID of the parent node.
|
| children_ids: List of child node IDs to associate with the parent node.
|
|
|
| Returns:
|
| The subcache object for the node.
|
| """
|
| subcache = super()._ensure_subcache(node_id, children_ids)
|
| for child_id in children_ids:
|
| self.descendants[node_id].add(child_id)
|
| self.ancestors[child_id].add(node_id)
|
| return subcache
|
|
|
| def _cleanup_ancestors(self, node_id):
|
| """
|
| Check if ancestors of a node can be removed from the cache.
|
|
|
| Args:
|
| node_id: The ID of the node whose ancestors are to be checked.
|
| """
|
| for ancestor_id in self.ancestors.get(node_id, []):
|
| if ancestor_id in self.executed_nodes:
|
|
|
| if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
| self._remove_node(ancestor_id)
|
|
|
| def _remove_node(self, node_id):
|
| """
|
| Remove a node from the cache.
|
|
|
| Args:
|
| node_id: The ID of the node to remove.
|
| """
|
| cache_key = self.cache_key_set.get_data_key(node_id)
|
| if cache_key in self.cache:
|
| del self.cache[cache_key]
|
| subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
| if subcache_key in self.subcaches:
|
| del self.subcaches[subcache_key]
|
|
|
| def clean_unused(self):
|
| """
|
| Clean up unused nodes. This is a no-op for this cache implementation.
|
| """
|
| pass
|
|
|
| def recursive_debug_dump(self):
|
| """
|
| Dump the cache and dependency graph for debugging.
|
|
|
| Returns:
|
| A list containing the cache state and dependency graph.
|
| """
|
| result = super().recursive_debug_dump()
|
| result.append({
|
| "descendants": self.descendants,
|
| "ancestors": self.ancestors,
|
| "executed_nodes": list(self.executed_nodes),
|
| })
|
| return result
|
|
|