From d1d53c14be8442fca19aae978e944edad1935d46 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:21:23 -0700 Subject: [PATCH] Revert "feat: Add CacheProvider API for external distributed caching (#12056)" (#12912) This reverts commit af7b4a921d7abab7c852d7b5febb654be6e57eba. --- comfy_api/latest/__init__.py | 35 -- comfy_api/latest/_caching.py | 42 -- comfy_execution/cache_provider.py | 138 ------ comfy_execution/caching.py | 177 +------- comfy_execution/graph.py | 6 +- execution.py | 141 +++--- .../execution_test/test_cache_provider.py | 403 ------------------ 7 files changed, 83 insertions(+), 859 deletions(-) delete mode 100644 comfy_api/latest/_caching.py delete mode 100644 comfy_execution/cache_provider.py delete mode 100644 tests-unit/execution_test/test_cache_provider.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 04973fea0..f2399422b 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -25,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase): super().__init__() self.node_replacement = self.NodeReplacement() self.execution = self.Execution() - self.caching = self.Caching() class NodeReplacement(ProxiedSingleton): async def register(self, node_replace: io.NodeReplace) -> None: @@ -85,36 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) - class Caching(ProxiedSingleton): - """ - External cache provider API for sharing cached node outputs - across ComfyUI instances. - - Example:: - - from comfy_api.latest import Caching - - class MyCacheProvider(Caching.CacheProvider): - async def on_lookup(self, context): - ... # check external storage - - async def on_store(self, context, value): - ... # store to external storage - - Caching.register_provider(MyCacheProvider()) - """ - from ._caching import CacheProvider, CacheContext, CacheValue - - async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: - """Register an external cache provider. Providers are called in registration order.""" - from comfy_execution.cache_provider import register_cache_provider - register_cache_provider(provider) - - async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: - """Unregister a previously registered cache provider.""" - from comfy_execution.cache_provider import unregister_cache_provider - unregister_cache_provider(provider) - class ComfyExtension(ABC): async def on_load(self) -> None: """ @@ -147,9 +116,6 @@ class Types: VOXEL = VOXEL File3D = File3D - -Caching = ComfyAPI_latest.Caching - ComfyAPI = ComfyAPI_latest # Create a synchronous version of the API @@ -169,7 +135,6 @@ __all__ = [ "Input", "InputImpl", "Types", - "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py deleted file mode 100644 index 30c8848cd..000000000 --- a/comfy_api/latest/_caching.py +++ /dev/null @@ -1,42 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional -from dataclasses import dataclass - - -@dataclass -class CacheContext: - node_id: str - class_type: str - cache_key_hash: str # SHA256 hex digest - - -@dataclass -class CacheValue: - outputs: list - ui: dict = None - - -class CacheProvider(ABC): - """Abstract base class for external cache providers. - Exceptions from provider methods are caught by the caller and never break execution. - """ - - @abstractmethod - async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: - """Called on local cache miss. Return CacheValue if found, None otherwise.""" - pass - - @abstractmethod - async def on_store(self, context: CacheContext, value: CacheValue) -> None: - """Called after local store. Dispatched via asyncio.create_task.""" - pass - - def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: - """Return False to skip external caching for this node. Default: True.""" - return True - - def on_prompt_start(self, prompt_id: str) -> None: - pass - - def on_prompt_end(self, prompt_id: str) -> None: - pass diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py deleted file mode 100644 index d455d08e8..000000000 --- a/comfy_execution/cache_provider.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Optional, Tuple, List -import hashlib -import json -import logging -import threading - -# Public types — source of truth is comfy_api.latest._caching -from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported) - -_logger = logging.getLogger(__name__) - - -_providers: List[CacheProvider] = [] -_providers_lock = threading.Lock() -_providers_snapshot: Tuple[CacheProvider, ...] = () - - -def register_cache_provider(provider: CacheProvider) -> None: - """Register an external cache provider. Providers are called in registration order.""" - global _providers_snapshot - with _providers_lock: - if provider in _providers: - _logger.warning(f"Provider {provider.__class__.__name__} already registered") - return - _providers.append(provider) - _providers_snapshot = tuple(_providers) - _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") - - -def unregister_cache_provider(provider: CacheProvider) -> None: - global _providers_snapshot - with _providers_lock: - try: - _providers.remove(provider) - _providers_snapshot = tuple(_providers) - _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") - except ValueError: - _logger.warning(f"Provider {provider.__class__.__name__} was not registered") - - -def _get_cache_providers() -> Tuple[CacheProvider, ...]: - return _providers_snapshot - - -def _has_cache_providers() -> bool: - return bool(_providers_snapshot) - - -def _clear_cache_providers() -> None: - global _providers_snapshot - with _providers_lock: - _providers.clear() - _providers_snapshot = () - - -def _canonicalize(obj: Any) -> Any: - # Convert to canonical JSON-serializable form with deterministic ordering. - # Frozensets have non-deterministic iteration order between Python sessions. - # Raises ValueError for non-cacheable types (Unhashable, unknown) so that - # _serialize_cache_key returns None and external caching is skipped. - if isinstance(obj, frozenset): - return ("__frozenset__", sorted( - [_canonicalize(item) for item in obj], - key=lambda x: json.dumps(x, sort_keys=True) - )) - elif isinstance(obj, set): - return ("__set__", sorted( - [_canonicalize(item) for item in obj], - key=lambda x: json.dumps(x, sort_keys=True) - )) - elif isinstance(obj, tuple): - return ("__tuple__", [_canonicalize(item) for item in obj]) - elif isinstance(obj, list): - return [_canonicalize(item) for item in obj] - elif isinstance(obj, dict): - return {"__dict__": sorted( - [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], - key=lambda x: json.dumps(x, sort_keys=True) - )} - elif isinstance(obj, (int, float, str, bool, type(None))): - return (type(obj).__name__, obj) - elif isinstance(obj, bytes): - return ("__bytes__", obj.hex()) - else: - raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") - - -def _serialize_cache_key(cache_key: Any) -> Optional[str]: - # Returns deterministic SHA256 hex digest, or None on failure. - # Uses JSON (not pickle) because pickle is non-deterministic across sessions. - try: - canonical = _canonicalize(cache_key) - json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) - return hashlib.sha256(json_str.encode('utf-8')).hexdigest() - except Exception as e: - _logger.warning(f"Failed to serialize cache key: {e}") - return None - - -def _contains_self_unequal(obj: Any) -> bool: - # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will - # never hit locally, but serialized form would match externally. Skip these. - try: - if not (obj == obj): - return True - except Exception: - return True - if isinstance(obj, (frozenset, tuple, list, set)): - return any(_contains_self_unequal(item) for item in obj) - if isinstance(obj, dict): - return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) - if hasattr(obj, 'value'): - return _contains_self_unequal(obj.value) - return False - - -def _estimate_value_size(value: CacheValue) -> int: - try: - import torch - except ImportError: - return 0 - - total = 0 - - def estimate(obj): - nonlocal total - if isinstance(obj, torch.Tensor): - total += obj.numel() * obj.element_size() - elif isinstance(obj, dict): - for v in obj.values(): - estimate(v) - elif isinstance(obj, (list, tuple)): - for item in obj: - estimate(item) - - for output in value.outputs: - estimate(output) - return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 750bddf2e..326a279fc 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,4 +1,3 @@ -import asyncio import bisect import gc import itertools @@ -155,7 +154,6 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} - self._pending_store_tasks: set = set() async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt @@ -198,134 +196,18 @@ class BasicCache: def poll(self, **kwargs): pass - def get_local(self, node_id): + def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] - return None - - def set_local(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - async def _set_immediate(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - await self._notify_providers_store(node_id, cache_key, value) - - async def _get_immediate(self, node_id): - if not self.initialized: - return None - cache_key = self.cache_key_set.get_data_key(node_id) - - if cache_key in self.cache: - return self.cache[cache_key] - - external_result = await self._check_providers_lookup(node_id, cache_key) - if external_result is not None: - self.cache[cache_key] = external_result - return external_result - - return None - - async def _notify_providers_store(self, node_id, cache_key, value): - from comfy_execution.cache_provider import ( - _has_cache_providers, _get_cache_providers, - CacheValue, _contains_self_unequal, _logger - ) - - if not _has_cache_providers(): - return - if not self._is_external_cacheable_value(value): - return - if _contains_self_unequal(cache_key): - return - - context = self._build_context(node_id, cache_key) - if context is None: - return - cache_value = CacheValue(outputs=value.outputs, ui=value.ui) - - for provider in _get_cache_providers(): - try: - if provider.should_cache(context, cache_value): - task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value)) - self._pending_store_tasks.add(task) - task.add_done_callback(self._pending_store_tasks.discard) - except Exception as e: - _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") - - @staticmethod - async def _safe_provider_store(provider, context, cache_value): - from comfy_execution.cache_provider import _logger - try: - await provider.on_store(context, cache_value) - except Exception as e: - _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") - - async def _check_providers_lookup(self, node_id, cache_key): - from comfy_execution.cache_provider import ( - _has_cache_providers, _get_cache_providers, - CacheValue, _contains_self_unequal, _logger - ) - - if not _has_cache_providers(): - return None - if _contains_self_unequal(cache_key): - return None - - context = self._build_context(node_id, cache_key) - if context is None: - return None - - for provider in _get_cache_providers(): - try: - if not provider.should_cache(context): - continue - result = await provider.on_lookup(context) - if result is not None: - if not isinstance(result, CacheValue): - _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") - continue - if not isinstance(result.outputs, (list, tuple)): - _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") - continue - from execution import CacheEntry - return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs)) - except Exception as e: - _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") - - return None - - def _is_external_cacheable_value(self, value): - return hasattr(value, 'outputs') and hasattr(value, 'ui') - - def _get_class_type(self, node_id): - if not self.initialized or not self.dynprompt: - return '' - try: - return self.dynprompt.get_node(node_id).get('class_type', '') - except Exception: - return '' - - def _build_context(self, node_id, cache_key): - from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger - try: - cache_key_hash = _serialize_cache_key(cache_key) - if cache_key_hash is None: - return None - return CacheContext( - node_id=node_id, - class_type=self._get_class_type(node_id), - cache_key_hash=cache_key_hash, - ) - except Exception as e: - _logger.warning(f"Failed to build cache context for node {node_id}: {e}") + else: return None async def _ensure_subcache(self, node_id, children_ids): @@ -375,27 +257,16 @@ class HierarchicalCache(BasicCache): return None return cache - async def get(self, node_id): + def get(self, node_id): cache = self._get_cache_for(node_id) if cache is None: return None - return await cache._get_immediate(node_id) + return cache._get_immediate(node_id) - def get_local(self, node_id): - cache = self._get_cache_for(node_id) - if cache is None: - return None - return BasicCache.get_local(cache, node_id) - - async def set(self, node_id, value): + def set(self, node_id, value): cache = self._get_cache_for(node_id) assert cache is not None - await cache._set_immediate(node_id, value) - - def set_local(self, node_id, value): - cache = self._get_cache_for(node_id) - assert cache is not None - BasicCache.set_local(cache, node_id, value) + cache._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) @@ -416,16 +287,10 @@ class NullCache: def poll(self, **kwargs): pass - async def get(self, node_id): + def get(self, node_id): return None - def get_local(self, node_id): - return None - - async def set(self, node_id, value): - pass - - def set_local(self, node_id, value): + def set(self, node_id, value): pass async def ensure_subcache_for(self, node_id, children_ids): @@ -457,18 +322,18 @@ class LRUCache(BasicCache): del self.children[key] self._clean_subcaches() - async def get(self, node_id): + def get(self, node_id): self._mark_used(node_id) - return await self._get_immediate(node_id) + return self._get_immediate(node_id) def _mark_used(self, node_id): cache_key = self.cache_key_set.get_data_key(node_id) if cache_key is not None: self.used_generation[cache_key] = self.generation - async def set(self, node_id, value): + def set(self, node_id, value): self._mark_used(node_id) - return await self._set_immediate(node_id, value) + return self._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes @@ -508,13 +373,13 @@ class RAMPressureCache(LRUCache): def clean_unused(self): self._clean_subcaches() - async def set(self, node_id, value): + def set(self, node_id, value): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - await super().set(node_id, value) + super().set(node_id, value) - async def get(self, node_id): + def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - return await super().get(node_id) + return super().get(node_id) def poll(self, ram_headroom): def _ram_gb(): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index c47f3c79b..9d170b16e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners = {} def is_cached(self, node_id): - return self.output_cache.get_local(node_id) is not None + return self.output_cache.get(node_id) is not None def cache_link(self, from_node_id, to_node_id): if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id) + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) @@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort): if value is None: return None #Write back to the main cache on touch. - self.output_cache.set_local(from_node_id, value) + self.output_cache.set(from_node_id, value) return value def cache_update(self, node_id, value): diff --git a/execution.py b/execution.py index a8e8fc59f..a7791efed 100644 --- a/execution.py +++ b/execution.py @@ -40,7 +40,6 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io -from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger class ExecutionResult(Enum): @@ -419,7 +418,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - cached = await caches.outputs.get(unique_id) + cached = caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} @@ -475,10 +474,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) - obj = await caches.objects.get(unique_id) + obj = caches.objects.get(unique_id) if obj is None: obj = class_def() - await caches.objects.set(unique_id, obj) + caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -589,7 +588,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) execution_list.cache_update(unique_id, cache_entry) - await caches.outputs.set(unique_id, cache_entry) + caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -685,19 +684,6 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) - def _notify_prompt_lifecycle(self, event: str, prompt_id: str): - if not _has_cache_providers(): - return - - for provider in _get_cache_providers(): - try: - if event == "start": - provider.on_prompt_start(prompt_id) - elif event == "end": - provider.on_prompt_end(prompt_id) - except Exception as e: - _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -714,75 +700,66 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - self._notify_prompt_lifecycle("start", prompt_id) + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - try: - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + cached_nodes = [] + for node_id in prompt: + if self.caches.outputs.get(node_id) is not None: + cached_nodes.append(node_id) - node_ids = list(prompt.keys()) - cache_results = await asyncio.gather( - *(self.caches.outputs.get(node_id) for node_id in node_ids) - ) - cached_nodes = [ - node_id for node_id, result in zip(node_ids, cache_results) - if result is not None - ] + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + else: + # Only execute when the while-loop ends without break + self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) - - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() - finally: - self._notify_prompt_lifecycle("end", prompt_id) + ui_outputs = {} + meta_outputs = {} + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } + self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py deleted file mode 100644 index ac3814746..000000000 --- a/tests-unit/execution_test/test_cache_provider.py +++ /dev/null @@ -1,403 +0,0 @@ -"""Tests for external cache provider API.""" - -import importlib.util -import pytest -from typing import Optional - - -def _torch_available() -> bool: - """Check if PyTorch is available.""" - return importlib.util.find_spec("torch") is not None - - -from comfy_execution.cache_provider import ( - CacheProvider, - CacheContext, - CacheValue, - register_cache_provider, - unregister_cache_provider, - _get_cache_providers, - _has_cache_providers, - _clear_cache_providers, - _serialize_cache_key, - _contains_self_unequal, - _estimate_value_size, - _canonicalize, -) - - -class TestCanonicalize: - """Test _canonicalize function for deterministic ordering.""" - - def test_frozenset_ordering_is_deterministic(self): - """Frozensets should produce consistent canonical form regardless of iteration order.""" - # Create two frozensets with same content - fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) - fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) - - result1 = _canonicalize(fs1) - result2 = _canonicalize(fs2) - - assert result1 == result2 - - def test_nested_frozenset_ordering(self): - """Nested frozensets should also be deterministically ordered.""" - inner1 = frozenset([1, 2, 3]) - inner2 = frozenset([3, 2, 1]) - - fs1 = frozenset([("key", inner1)]) - fs2 = frozenset([("key", inner2)]) - - result1 = _canonicalize(fs1) - result2 = _canonicalize(fs2) - - assert result1 == result2 - - def test_dict_ordering(self): - """Dicts should be sorted by key.""" - d1 = {"z": 1, "a": 2, "m": 3} - d2 = {"a": 2, "m": 3, "z": 1} - - result1 = _canonicalize(d1) - result2 = _canonicalize(d2) - - assert result1 == result2 - - def test_tuple_preserved(self): - """Tuples should be marked and preserved.""" - t = (1, 2, 3) - result = _canonicalize(t) - - assert result[0] == "__tuple__" - - def test_list_preserved(self): - """Lists should be recursively canonicalized.""" - lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] - result = _canonicalize(lst) - - # First element should be canonicalized dict - assert "__dict__" in result[0] - # Second element should be canonicalized frozenset - assert result[1][0] == "__frozenset__" - - def test_primitives_include_type(self): - """Primitive types should include type name for disambiguation.""" - assert _canonicalize(42) == ("int", 42) - assert _canonicalize(3.14) == ("float", 3.14) - assert _canonicalize("hello") == ("str", "hello") - assert _canonicalize(True) == ("bool", True) - assert _canonicalize(None) == ("NoneType", None) - - def test_int_and_str_distinguished(self): - """int 7 and str '7' must produce different canonical forms.""" - assert _canonicalize(7) != _canonicalize("7") - - def test_bytes_converted(self): - """Bytes should be converted to hex string.""" - b = b"\x00\xff" - result = _canonicalize(b) - - assert result[0] == "__bytes__" - assert result[1] == "00ff" - - def test_set_ordering(self): - """Sets should be sorted like frozensets.""" - s1 = {3, 1, 2} - s2 = {1, 2, 3} - - result1 = _canonicalize(s1) - result2 = _canonicalize(s2) - - assert result1 == result2 - assert result1[0] == "__set__" - - def test_unknown_type_raises(self): - """Unknown types should raise ValueError (fail-closed).""" - class CustomObj: - pass - with pytest.raises(ValueError): - _canonicalize(CustomObj()) - - def test_object_with_value_attr_raises(self): - """Objects with .value attribute (Unhashable-like) should raise ValueError.""" - class FakeUnhashable: - def __init__(self): - self.value = float('nan') - with pytest.raises(ValueError): - _canonicalize(FakeUnhashable()) - - -class TestSerializeCacheKey: - """Test _serialize_cache_key for deterministic hashing.""" - - def test_same_content_same_hash(self): - """Same content should produce same hash.""" - key1 = frozenset([("node_1", frozenset([("input", "value")]))]) - key2 = frozenset([("node_1", frozenset([("input", "value")]))]) - - hash1 = _serialize_cache_key(key1) - hash2 = _serialize_cache_key(key2) - - assert hash1 == hash2 - - def test_different_content_different_hash(self): - """Different content should produce different hash.""" - key1 = frozenset([("node_1", "value_a")]) - key2 = frozenset([("node_1", "value_b")]) - - hash1 = _serialize_cache_key(key1) - hash2 = _serialize_cache_key(key2) - - assert hash1 != hash2 - - def test_returns_hex_string(self): - """Should return hex string (SHA256 hex digest).""" - key = frozenset([("test", 123)]) - result = _serialize_cache_key(key) - - assert isinstance(result, str) - assert len(result) == 64 # SHA256 hex digest is 64 chars - - def test_complex_nested_structure(self): - """Complex nested structures should hash deterministically.""" - # Note: frozensets can only contain hashable types, so we use - # nested frozensets of tuples to represent dict-like structures - key = frozenset([ - ("node_1", frozenset([ - ("input_a", ("tuple", "value")), - ("input_b", frozenset([("nested", "dict")])), - ])), - ("node_2", frozenset([ - ("param", 42), - ])), - ]) - - # Hash twice to verify determinism - hash1 = _serialize_cache_key(key) - hash2 = _serialize_cache_key(key) - - assert hash1 == hash2 - - def test_dict_in_cache_key(self): - """Dicts passed directly to _serialize_cache_key should work.""" - key = {"node_1": {"input": "value"}, "node_2": 42} - - hash1 = _serialize_cache_key(key) - hash2 = _serialize_cache_key(key) - - assert hash1 == hash2 - assert isinstance(hash1, str) - assert len(hash1) == 64 - - def test_unknown_type_returns_none(self): - """Non-cacheable types should return None (fail-closed).""" - class CustomObj: - pass - assert _serialize_cache_key(CustomObj()) is None - - -class TestContainsSelfUnequal: - """Test _contains_self_unequal utility function.""" - - def test_nan_float_detected(self): - """NaN floats should be detected (not equal to itself).""" - assert _contains_self_unequal(float('nan')) is True - - def test_regular_float_not_detected(self): - """Regular floats are equal to themselves.""" - assert _contains_self_unequal(3.14) is False - assert _contains_self_unequal(0.0) is False - assert _contains_self_unequal(-1.5) is False - - def test_infinity_not_detected(self): - """Infinity is equal to itself.""" - assert _contains_self_unequal(float('inf')) is False - assert _contains_self_unequal(float('-inf')) is False - - def test_nan_in_list(self): - """NaN in list should be detected.""" - assert _contains_self_unequal([1, 2, float('nan'), 4]) is True - assert _contains_self_unequal([1, 2, 3, 4]) is False - - def test_nan_in_tuple(self): - """NaN in tuple should be detected.""" - assert _contains_self_unequal((1, float('nan'))) is True - assert _contains_self_unequal((1, 2, 3)) is False - - def test_nan_in_frozenset(self): - """NaN in frozenset should be detected.""" - assert _contains_self_unequal(frozenset([1, float('nan')])) is True - assert _contains_self_unequal(frozenset([1, 2, 3])) is False - - def test_nan_in_dict_value(self): - """NaN in dict value should be detected.""" - assert _contains_self_unequal({"key": float('nan')}) is True - assert _contains_self_unequal({"key": 42}) is False - - def test_nan_in_nested_structure(self): - """NaN in deeply nested structure should be detected.""" - nested = {"level1": [{"level2": (1, 2, float('nan'))}]} - assert _contains_self_unequal(nested) is True - - def test_non_numeric_types(self): - """Non-numeric types should not be self-unequal.""" - assert _contains_self_unequal("string") is False - assert _contains_self_unequal(None) is False - assert _contains_self_unequal(True) is False - - def test_object_with_nan_value_attribute(self): - """Objects wrapping NaN in .value should be detected.""" - class NanWrapper: - def __init__(self): - self.value = float('nan') - assert _contains_self_unequal(NanWrapper()) is True - - def test_custom_self_unequal_object(self): - """Custom objects where not (x == x) should be detected.""" - class NeverEqual: - def __eq__(self, other): - return False - assert _contains_self_unequal(NeverEqual()) is True - - -class TestEstimateValueSize: - """Test _estimate_value_size utility function.""" - - def test_empty_outputs(self): - """Empty outputs should have zero size.""" - value = CacheValue(outputs=[]) - assert _estimate_value_size(value) == 0 - - @pytest.mark.skipif( - not _torch_available(), - reason="PyTorch not available" - ) - def test_tensor_size_estimation(self): - """Tensor size should be estimated correctly.""" - import torch - - # 1000 float32 elements = 4000 bytes - tensor = torch.zeros(1000, dtype=torch.float32) - value = CacheValue(outputs=[[tensor]]) - - size = _estimate_value_size(value) - assert size == 4000 - - @pytest.mark.skipif( - not _torch_available(), - reason="PyTorch not available" - ) - def test_nested_tensor_in_dict(self): - """Tensors nested in dicts should be counted.""" - import torch - - tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes - value = CacheValue(outputs=[[{"samples": tensor}]]) - - size = _estimate_value_size(value) - assert size == 400 - - -class TestProviderRegistry: - """Test cache provider registration and retrieval.""" - - def setup_method(self): - """Clear providers before each test.""" - _clear_cache_providers() - - def teardown_method(self): - """Clear providers after each test.""" - _clear_cache_providers() - - def test_register_provider(self): - """Provider should be registered successfully.""" - provider = MockCacheProvider() - register_cache_provider(provider) - - assert _has_cache_providers() is True - providers = _get_cache_providers() - assert len(providers) == 1 - assert providers[0] is provider - - def test_unregister_provider(self): - """Provider should be unregistered successfully.""" - provider = MockCacheProvider() - register_cache_provider(provider) - unregister_cache_provider(provider) - - assert _has_cache_providers() is False - - def test_multiple_providers(self): - """Multiple providers can be registered.""" - provider1 = MockCacheProvider() - provider2 = MockCacheProvider() - - register_cache_provider(provider1) - register_cache_provider(provider2) - - providers = _get_cache_providers() - assert len(providers) == 2 - - def test_duplicate_registration_ignored(self): - """Registering same provider twice should be ignored.""" - provider = MockCacheProvider() - - register_cache_provider(provider) - register_cache_provider(provider) # Should be ignored - - providers = _get_cache_providers() - assert len(providers) == 1 - - def test_clear_providers(self): - """_clear_cache_providers should remove all providers.""" - provider1 = MockCacheProvider() - provider2 = MockCacheProvider() - - register_cache_provider(provider1) - register_cache_provider(provider2) - _clear_cache_providers() - - assert _has_cache_providers() is False - assert len(_get_cache_providers()) == 0 - - -class TestCacheContext: - """Test CacheContext dataclass.""" - - def test_context_creation(self): - """CacheContext should be created with all fields.""" - context = CacheContext( - node_id="node-456", - class_type="KSampler", - cache_key_hash="a" * 64, - ) - - assert context.node_id == "node-456" - assert context.class_type == "KSampler" - assert context.cache_key_hash == "a" * 64 - - -class TestCacheValue: - """Test CacheValue dataclass.""" - - def test_value_creation(self): - """CacheValue should be created with outputs.""" - outputs = [[{"samples": "tensor_data"}]] - value = CacheValue(outputs=outputs) - - assert value.outputs == outputs - - -class MockCacheProvider(CacheProvider): - """Mock cache provider for testing.""" - - def __init__(self): - self.lookups = [] - self.stores = [] - - async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: - self.lookups.append(context) - return None - - async def on_store(self, context: CacheContext, value: CacheValue) -> None: - self.stores.append((context, value))