From e6bb7f03dd484ed06bb31b310157f3860bc9a977 Mon Sep 17 00:00:00 2001 From: Deep Mehta <42841935+deepme987@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:56:21 -0700 Subject: [PATCH] =?UTF-8?q?Revert=20"Revert=20"feat:=20Add=20CacheProvider?= =?UTF-8?q?=20API=20for=20external=20distributed=20caching=20=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit d1d53c14be8442fca19aae978e944edad1935d46. --- comfy_api/latest/__init__.py | 35 ++ comfy_api/latest/_caching.py | 42 ++ comfy_execution/cache_provider.py | 138 ++++++ comfy_execution/caching.py | 177 +++++++- comfy_execution/graph.py | 6 +- execution.py | 141 +++--- .../execution_test/test_cache_provider.py | 403 ++++++++++++++++++ 7 files changed, 859 insertions(+), 83 deletions(-) create mode 100644 comfy_api/latest/_caching.py create mode 100644 comfy_execution/cache_provider.py create mode 100644 tests-unit/execution_test/test_cache_provider.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index f2399422b..04973fea0 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase): super().__init__() self.node_replacement = self.NodeReplacement() self.execution = self.Execution() + self.caching = self.Caching() class NodeReplacement(ProxiedSingleton): async def register(self, node_replace: io.NodeReplace) -> None: @@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) + class Caching(ProxiedSingleton): + """ + External cache provider API for sharing cached node outputs + across ComfyUI instances. + + Example:: + + from comfy_api.latest import Caching + + class MyCacheProvider(Caching.CacheProvider): + async def on_lookup(self, context): + ... # check external storage + + async def on_store(self, context, value): + ... # store to external storage + + Caching.register_provider(MyCacheProvider()) + """ + from ._caching import CacheProvider, CacheContext, CacheValue + + async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Register an external cache provider. Providers are called in registration order.""" + from comfy_execution.cache_provider import register_cache_provider + register_cache_provider(provider) + + async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Unregister a previously registered cache provider.""" + from comfy_execution.cache_provider import unregister_cache_provider + unregister_cache_provider(provider) + class ComfyExtension(ABC): async def on_load(self) -> None: """ @@ -116,6 +147,9 @@ class Types: VOXEL = VOXEL File3D = File3D + +Caching = ComfyAPI_latest.Caching + ComfyAPI = ComfyAPI_latest # Create a synchronous version of the API @@ -135,6 +169,7 @@ __all__ = [ "Input", "InputImpl", "Types", + "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py new file mode 100644 index 000000000..30c8848cd --- /dev/null +++ b/comfy_api/latest/_caching.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class CacheContext: + node_id: str + class_type: str + cache_key_hash: str # SHA256 hex digest + + +@dataclass +class CacheValue: + outputs: list + ui: dict = None + + +class CacheProvider(ABC): + """Abstract base class for external cache providers. + Exceptions from provider methods are caught by the caller and never break execution. + """ + + @abstractmethod + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + """Called on local cache miss. Return CacheValue if found, None otherwise.""" + pass + + @abstractmethod + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + """Called after local store. Dispatched via asyncio.create_task.""" + pass + + def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: + """Return False to skip external caching for this node. Default: True.""" + return True + + def on_prompt_start(self, prompt_id: str) -> None: + pass + + def on_prompt_end(self, prompt_id: str) -> None: + pass diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py new file mode 100644 index 000000000..d455d08e8 --- /dev/null +++ b/comfy_execution/cache_provider.py @@ -0,0 +1,138 @@ +from typing import Any, Optional, Tuple, List +import hashlib +import json +import logging +import threading + +# Public types — source of truth is comfy_api.latest._caching +from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported) + +_logger = logging.getLogger(__name__) + + +_providers: List[CacheProvider] = [] +_providers_lock = threading.Lock() +_providers_snapshot: Tuple[CacheProvider, ...] = () + + +def register_cache_provider(provider: CacheProvider) -> None: + """Register an external cache provider. Providers are called in registration order.""" + global _providers_snapshot + with _providers_lock: + if provider in _providers: + _logger.warning(f"Provider {provider.__class__.__name__} already registered") + return + _providers.append(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") + + +def unregister_cache_provider(provider: CacheProvider) -> None: + global _providers_snapshot + with _providers_lock: + try: + _providers.remove(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") + except ValueError: + _logger.warning(f"Provider {provider.__class__.__name__} was not registered") + + +def _get_cache_providers() -> Tuple[CacheProvider, ...]: + return _providers_snapshot + + +def _has_cache_providers() -> bool: + return bool(_providers_snapshot) + + +def _clear_cache_providers() -> None: + global _providers_snapshot + with _providers_lock: + _providers.clear() + _providers_snapshot = () + + +def _canonicalize(obj: Any) -> Any: + # Convert to canonical JSON-serializable form with deterministic ordering. + # Frozensets have non-deterministic iteration order between Python sessions. + # Raises ValueError for non-cacheable types (Unhashable, unknown) so that + # _serialize_cache_key returns None and external caching is skipped. + if isinstance(obj, frozenset): + return ("__frozenset__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, set): + return ("__set__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, tuple): + return ("__tuple__", [_canonicalize(item) for item in obj]) + elif isinstance(obj, list): + return [_canonicalize(item) for item in obj] + elif isinstance(obj, dict): + return {"__dict__": sorted( + [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], + key=lambda x: json.dumps(x, sort_keys=True) + )} + elif isinstance(obj, (int, float, str, bool, type(None))): + return (type(obj).__name__, obj) + elif isinstance(obj, bytes): + return ("__bytes__", obj.hex()) + else: + raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") + + +def _serialize_cache_key(cache_key: Any) -> Optional[str]: + # Returns deterministic SHA256 hex digest, or None on failure. + # Uses JSON (not pickle) because pickle is non-deterministic across sessions. + try: + canonical = _canonicalize(cache_key) + json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) + return hashlib.sha256(json_str.encode('utf-8')).hexdigest() + except Exception as e: + _logger.warning(f"Failed to serialize cache key: {e}") + return None + + +def _contains_self_unequal(obj: Any) -> bool: + # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will + # never hit locally, but serialized form would match externally. Skip these. + try: + if not (obj == obj): + return True + except Exception: + return True + if isinstance(obj, (frozenset, tuple, list, set)): + return any(_contains_self_unequal(item) for item in obj) + if isinstance(obj, dict): + return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) + if hasattr(obj, 'value'): + return _contains_self_unequal(obj.value) + return False + + +def _estimate_value_size(value: CacheValue) -> int: + try: + import torch + except ImportError: + return 0 + + total = 0 + + def estimate(obj): + nonlocal total + if isinstance(obj, torch.Tensor): + total += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + estimate(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + estimate(item) + + for output in value.outputs: + estimate(output) + return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..750bddf2e 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,3 +1,4 @@ +import asyncio import bisect import gc import itertools @@ -154,6 +155,7 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} + self._pending_store_tasks: set = set() async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt @@ -196,18 +198,134 @@ class BasicCache: def poll(self, **kwargs): pass - def _set_immediate(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - def _get_immediate(self, node_id): + def get_local(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] - else: + return None + + def set_local(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + async def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + await self._notify_providers_store(node_id, cache_key, value) + + async def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + + if cache_key in self.cache: + return self.cache[cache_key] + + external_result = await self._check_providers_lookup(node_id, cache_key) + if external_result is not None: + self.cache[cache_key] = external_result + return external_result + + return None + + async def _notify_providers_store(self, node_id, cache_key, value): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not _has_cache_providers(): + return + if not self._is_external_cacheable_value(value): + return + if _contains_self_unequal(cache_key): + return + + context = self._build_context(node_id, cache_key) + if context is None: + return + cache_value = CacheValue(outputs=value.outputs, ui=value.ui) + + for provider in _get_cache_providers(): + try: + if provider.should_cache(context, cache_value): + task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value)) + self._pending_store_tasks.add(task) + task.add_done_callback(self._pending_store_tasks.discard) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + + @staticmethod + async def _safe_provider_store(provider, context, cache_value): + from comfy_execution.cache_provider import _logger + try: + await provider.on_store(context, cache_value) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") + + async def _check_providers_lookup(self, node_id, cache_key): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not _has_cache_providers(): + return None + if _contains_self_unequal(cache_key): + return None + + context = self._build_context(node_id, cache_key) + if context is None: + return None + + for provider in _get_cache_providers(): + try: + if not provider.should_cache(context): + continue + result = await provider.on_lookup(context) + if result is not None: + if not isinstance(result, CacheValue): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") + continue + if not isinstance(result.outputs, (list, tuple)): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") + continue + from execution import CacheEntry + return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs)) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") + + return None + + def _is_external_cacheable_value(self, value): + return hasattr(value, 'outputs') and hasattr(value, 'ui') + + def _get_class_type(self, node_id): + if not self.initialized or not self.dynprompt: + return '' + try: + return self.dynprompt.get_node(node_id).get('class_type', '') + except Exception: + return '' + + def _build_context(self, node_id, cache_key): + from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger + try: + cache_key_hash = _serialize_cache_key(cache_key) + if cache_key_hash is None: + return None + return CacheContext( + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key_hash=cache_key_hash, + ) + except Exception as e: + _logger.warning(f"Failed to build cache context for node {node_id}: {e}") return None async def _ensure_subcache(self, node_id, children_ids): @@ -257,16 +375,27 @@ class HierarchicalCache(BasicCache): return None return cache - def get(self, node_id): + async def get(self, node_id): cache = self._get_cache_for(node_id) if cache is None: return None - return cache._get_immediate(node_id) + return await cache._get_immediate(node_id) - def set(self, node_id, value): + def get_local(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return BasicCache.get_local(cache, node_id) + + async def set(self, node_id, value): cache = self._get_cache_for(node_id) assert cache is not None - cache._set_immediate(node_id, value) + await cache._set_immediate(node_id, value) + + def set_local(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + BasicCache.set_local(cache, node_id, value) async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) @@ -287,10 +416,16 @@ class NullCache: def poll(self, **kwargs): pass - def get(self, node_id): + async def get(self, node_id): return None - def set(self, node_id, value): + def get_local(self, node_id): + return None + + async def set(self, node_id, value): + pass + + def set_local(self, node_id, value): pass async def ensure_subcache_for(self, node_id, children_ids): @@ -322,18 +457,18 @@ class LRUCache(BasicCache): del self.children[key] self._clean_subcaches() - def get(self, node_id): + async def get(self, node_id): self._mark_used(node_id) - return self._get_immediate(node_id) + return await self._get_immediate(node_id) def _mark_used(self, node_id): cache_key = self.cache_key_set.get_data_key(node_id) if cache_key is not None: self.used_generation[cache_key] = self.generation - def set(self, node_id, value): + async def set(self, node_id, value): self._mark_used(node_id) - return self._set_immediate(node_id, value) + return await self._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes @@ -373,13 +508,13 @@ class RAMPressureCache(LRUCache): def clean_unused(self): self._clean_subcaches() - def set(self, node_id, value): + async def set(self, node_id, value): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - super().set(node_id, value) + await super().set(node_id, value) - def get(self, node_id): + async def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - return super().get(node_id) + return await super().get(node_id) def poll(self, ram_headroom): def _ram_gb(): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 9d170b16e..c47f3c79b 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners = {} def is_cached(self, node_id): - return self.output_cache.get(node_id) is not None + return self.output_cache.get_local(node_id) is not None def cache_link(self, from_node_id, to_node_id): if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id) if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) @@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort): if value is None: return None #Write back to the main cache on touch. - self.output_cache.set(from_node_id, value) + self.output_cache.set_local(from_node_id, value) return value def cache_update(self, node_id, value): diff --git a/execution.py b/execution.py index a7791efed..a8e8fc59f 100644 --- a/execution.py +++ b/execution.py @@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io +from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger class ExecutionResult(Enum): @@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - cached = caches.outputs.get(unique_id) + cached = await caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} @@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) - obj = caches.objects.get(unique_id) + obj = await caches.objects.get(unique_id) if obj is None: obj = class_def() - caches.objects.set(unique_id, obj) + await caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) execution_list.cache_update(unique_id, cache_entry) - caches.outputs.set(unique_id, cache_entry) + await caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -684,6 +685,19 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) + def _notify_prompt_lifecycle(self, event: str, prompt_id: str): + if not _has_cache_providers(): + return + + for provider in _get_cache_providers(): + try: + if event == "start": + provider.on_prompt_start(prompt_id) + elif event == "end": + provider.on_prompt_end(prompt_id) + except Exception as e: + _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -700,66 +714,75 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + self._notify_prompt_lifecycle("start", prompt_id) - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) + try: + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + node_ids = list(prompt.keys()) + cache_results = await asyncio.gather( + *(self.caches.outputs.get(node_id) for node_id in node_ids) + ) + cached_nodes = [ + node_id for node_id, result in zip(node_ids, cache_results) + if result is not None + ] - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + else: + # Only execute when the while-loop ends without break + self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + + ui_outputs = {} + meta_outputs = {} + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } + self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() + finally: + self._notify_prompt_lifecycle("end", prompt_id) async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py new file mode 100644 index 000000000..ac3814746 --- /dev/null +++ b/tests-unit/execution_test/test_cache_provider.py @@ -0,0 +1,403 @@ +"""Tests for external cache provider API.""" + +import importlib.util +import pytest +from typing import Optional + + +def _torch_available() -> bool: + """Check if PyTorch is available.""" + return importlib.util.find_spec("torch") is not None + + +from comfy_execution.cache_provider import ( + CacheProvider, + CacheContext, + CacheValue, + register_cache_provider, + unregister_cache_provider, + _get_cache_providers, + _has_cache_providers, + _clear_cache_providers, + _serialize_cache_key, + _contains_self_unequal, + _estimate_value_size, + _canonicalize, +) + + +class TestCanonicalize: + """Test _canonicalize function for deterministic ordering.""" + + def test_frozenset_ordering_is_deterministic(self): + """Frozensets should produce consistent canonical form regardless of iteration order.""" + # Create two frozensets with same content + fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) + fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_nested_frozenset_ordering(self): + """Nested frozensets should also be deterministically ordered.""" + inner1 = frozenset([1, 2, 3]) + inner2 = frozenset([3, 2, 1]) + + fs1 = frozenset([("key", inner1)]) + fs2 = frozenset([("key", inner2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_dict_ordering(self): + """Dicts should be sorted by key.""" + d1 = {"z": 1, "a": 2, "m": 3} + d2 = {"a": 2, "m": 3, "z": 1} + + result1 = _canonicalize(d1) + result2 = _canonicalize(d2) + + assert result1 == result2 + + def test_tuple_preserved(self): + """Tuples should be marked and preserved.""" + t = (1, 2, 3) + result = _canonicalize(t) + + assert result[0] == "__tuple__" + + def test_list_preserved(self): + """Lists should be recursively canonicalized.""" + lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] + result = _canonicalize(lst) + + # First element should be canonicalized dict + assert "__dict__" in result[0] + # Second element should be canonicalized frozenset + assert result[1][0] == "__frozenset__" + + def test_primitives_include_type(self): + """Primitive types should include type name for disambiguation.""" + assert _canonicalize(42) == ("int", 42) + assert _canonicalize(3.14) == ("float", 3.14) + assert _canonicalize("hello") == ("str", "hello") + assert _canonicalize(True) == ("bool", True) + assert _canonicalize(None) == ("NoneType", None) + + def test_int_and_str_distinguished(self): + """int 7 and str '7' must produce different canonical forms.""" + assert _canonicalize(7) != _canonicalize("7") + + def test_bytes_converted(self): + """Bytes should be converted to hex string.""" + b = b"\x00\xff" + result = _canonicalize(b) + + assert result[0] == "__bytes__" + assert result[1] == "00ff" + + def test_set_ordering(self): + """Sets should be sorted like frozensets.""" + s1 = {3, 1, 2} + s2 = {1, 2, 3} + + result1 = _canonicalize(s1) + result2 = _canonicalize(s2) + + assert result1 == result2 + assert result1[0] == "__set__" + + def test_unknown_type_raises(self): + """Unknown types should raise ValueError (fail-closed).""" + class CustomObj: + pass + with pytest.raises(ValueError): + _canonicalize(CustomObj()) + + def test_object_with_value_attr_raises(self): + """Objects with .value attribute (Unhashable-like) should raise ValueError.""" + class FakeUnhashable: + def __init__(self): + self.value = float('nan') + with pytest.raises(ValueError): + _canonicalize(FakeUnhashable()) + + +class TestSerializeCacheKey: + """Test _serialize_cache_key for deterministic hashing.""" + + def test_same_content_same_hash(self): + """Same content should produce same hash.""" + key1 = frozenset([("node_1", frozenset([("input", "value")]))]) + key2 = frozenset([("node_1", frozenset([("input", "value")]))]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 == hash2 + + def test_different_content_different_hash(self): + """Different content should produce different hash.""" + key1 = frozenset([("node_1", "value_a")]) + key2 = frozenset([("node_1", "value_b")]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 != hash2 + + def test_returns_hex_string(self): + """Should return hex string (SHA256 hex digest).""" + key = frozenset([("test", 123)]) + result = _serialize_cache_key(key) + + assert isinstance(result, str) + assert len(result) == 64 # SHA256 hex digest is 64 chars + + def test_complex_nested_structure(self): + """Complex nested structures should hash deterministically.""" + # Note: frozensets can only contain hashable types, so we use + # nested frozensets of tuples to represent dict-like structures + key = frozenset([ + ("node_1", frozenset([ + ("input_a", ("tuple", "value")), + ("input_b", frozenset([("nested", "dict")])), + ])), + ("node_2", frozenset([ + ("param", 42), + ])), + ]) + + # Hash twice to verify determinism + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + + def test_dict_in_cache_key(self): + """Dicts passed directly to _serialize_cache_key should work.""" + key = {"node_1": {"input": "value"}, "node_2": 42} + + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) == 64 + + def test_unknown_type_returns_none(self): + """Non-cacheable types should return None (fail-closed).""" + class CustomObj: + pass + assert _serialize_cache_key(CustomObj()) is None + + +class TestContainsSelfUnequal: + """Test _contains_self_unequal utility function.""" + + def test_nan_float_detected(self): + """NaN floats should be detected (not equal to itself).""" + assert _contains_self_unequal(float('nan')) is True + + def test_regular_float_not_detected(self): + """Regular floats are equal to themselves.""" + assert _contains_self_unequal(3.14) is False + assert _contains_self_unequal(0.0) is False + assert _contains_self_unequal(-1.5) is False + + def test_infinity_not_detected(self): + """Infinity is equal to itself.""" + assert _contains_self_unequal(float('inf')) is False + assert _contains_self_unequal(float('-inf')) is False + + def test_nan_in_list(self): + """NaN in list should be detected.""" + assert _contains_self_unequal([1, 2, float('nan'), 4]) is True + assert _contains_self_unequal([1, 2, 3, 4]) is False + + def test_nan_in_tuple(self): + """NaN in tuple should be detected.""" + assert _contains_self_unequal((1, float('nan'))) is True + assert _contains_self_unequal((1, 2, 3)) is False + + def test_nan_in_frozenset(self): + """NaN in frozenset should be detected.""" + assert _contains_self_unequal(frozenset([1, float('nan')])) is True + assert _contains_self_unequal(frozenset([1, 2, 3])) is False + + def test_nan_in_dict_value(self): + """NaN in dict value should be detected.""" + assert _contains_self_unequal({"key": float('nan')}) is True + assert _contains_self_unequal({"key": 42}) is False + + def test_nan_in_nested_structure(self): + """NaN in deeply nested structure should be detected.""" + nested = {"level1": [{"level2": (1, 2, float('nan'))}]} + assert _contains_self_unequal(nested) is True + + def test_non_numeric_types(self): + """Non-numeric types should not be self-unequal.""" + assert _contains_self_unequal("string") is False + assert _contains_self_unequal(None) is False + assert _contains_self_unequal(True) is False + + def test_object_with_nan_value_attribute(self): + """Objects wrapping NaN in .value should be detected.""" + class NanWrapper: + def __init__(self): + self.value = float('nan') + assert _contains_self_unequal(NanWrapper()) is True + + def test_custom_self_unequal_object(self): + """Custom objects where not (x == x) should be detected.""" + class NeverEqual: + def __eq__(self, other): + return False + assert _contains_self_unequal(NeverEqual()) is True + + +class TestEstimateValueSize: + """Test _estimate_value_size utility function.""" + + def test_empty_outputs(self): + """Empty outputs should have zero size.""" + value = CacheValue(outputs=[]) + assert _estimate_value_size(value) == 0 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_tensor_size_estimation(self): + """Tensor size should be estimated correctly.""" + import torch + + # 1000 float32 elements = 4000 bytes + tensor = torch.zeros(1000, dtype=torch.float32) + value = CacheValue(outputs=[[tensor]]) + + size = _estimate_value_size(value) + assert size == 4000 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_nested_tensor_in_dict(self): + """Tensors nested in dicts should be counted.""" + import torch + + tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes + value = CacheValue(outputs=[[{"samples": tensor}]]) + + size = _estimate_value_size(value) + assert size == 400 + + +class TestProviderRegistry: + """Test cache provider registration and retrieval.""" + + def setup_method(self): + """Clear providers before each test.""" + _clear_cache_providers() + + def teardown_method(self): + """Clear providers after each test.""" + _clear_cache_providers() + + def test_register_provider(self): + """Provider should be registered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + + assert _has_cache_providers() is True + providers = _get_cache_providers() + assert len(providers) == 1 + assert providers[0] is provider + + def test_unregister_provider(self): + """Provider should be unregistered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + unregister_cache_provider(provider) + + assert _has_cache_providers() is False + + def test_multiple_providers(self): + """Multiple providers can be registered.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + + providers = _get_cache_providers() + assert len(providers) == 2 + + def test_duplicate_registration_ignored(self): + """Registering same provider twice should be ignored.""" + provider = MockCacheProvider() + + register_cache_provider(provider) + register_cache_provider(provider) # Should be ignored + + providers = _get_cache_providers() + assert len(providers) == 1 + + def test_clear_providers(self): + """_clear_cache_providers should remove all providers.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + _clear_cache_providers() + + assert _has_cache_providers() is False + assert len(_get_cache_providers()) == 0 + + +class TestCacheContext: + """Test CacheContext dataclass.""" + + def test_context_creation(self): + """CacheContext should be created with all fields.""" + context = CacheContext( + node_id="node-456", + class_type="KSampler", + cache_key_hash="a" * 64, + ) + + assert context.node_id == "node-456" + assert context.class_type == "KSampler" + assert context.cache_key_hash == "a" * 64 + + +class TestCacheValue: + """Test CacheValue dataclass.""" + + def test_value_creation(self): + """CacheValue should be created with outputs.""" + outputs = [[{"samples": "tensor_data"}]] + value = CacheValue(outputs=outputs) + + assert value.outputs == outputs + + +class MockCacheProvider(CacheProvider): + """Mock cache provider for testing.""" + + def __init__(self): + self.lookups = [] + self.stores = [] + + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + self.lookups.append(context) + return None + + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + self.stores.append((context, value))