From 6540aa040027d89a00fda36873bb4f58f5198619 Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Mon, 19 Jan 2026 16:43:13 +0530 Subject: [PATCH] feat: Add CacheProvider API for external distributed caching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a public API for external cache providers, enabling distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods). New files: - comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue dataclasses, thread-safe provider registry, serialization utilities Modified files: - comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store, _check_providers_lookup), subcache exclusion, prompt ID propagation - execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to PromptExecutor, set _current_prompt_id on caches Key features: - Local-first caching (check local before external for performance) - NaN detection to prevent incorrect external cache hits - Subcache exclusion (ephemeral subgraph results not cached externally) - Thread-safe provider snapshot caching - Graceful error handling (provider errors logged, never break execution) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- comfy_execution/cache_provider.py | 267 ++++++++++++++++++++++++++++++ comfy_execution/caching.py | 109 +++++++++++- execution.py | 137 +++++++++------ 3 files changed, 457 insertions(+), 56 deletions(-) create mode 100644 comfy_execution/cache_provider.py diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py new file mode 100644 index 000000000..8c8d44620 --- /dev/null +++ b/comfy_execution/cache_provider.py @@ -0,0 +1,267 @@ +""" +External Cache Provider API for distributed caching. + +This module provides a public API for external cache providers, enabling +distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods). + +Example usage: + from comfy_execution.cache_provider import ( + CacheProvider, CacheContext, CacheValue, register_cache_provider + ) + + class MyRedisProvider(CacheProvider): + def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + # Check Redis/GCS for cached result + ... + + def on_store(self, context: CacheContext, value: CacheValue) -> None: + # Store to Redis/GCS (can be async internally) + ... + + register_cache_provider(MyRedisProvider()) +""" + +from abc import ABC, abstractmethod +from typing import Any, Optional, Tuple, List +from dataclasses import dataclass +import logging +import threading +import hashlib +import pickle +import math + +logger = logging.getLogger(__name__) + + +# ============================================================ +# Data Classes +# ============================================================ + +@dataclass +class CacheContext: + """Context passed to provider methods.""" + prompt_id: str # Current prompt execution ID + node_id: str # Node being cached + class_type: str # Node class type (e.g., "KSampler") + cache_key: Any # Raw cache key (frozenset structure) + cache_key_bytes: bytes # SHA256 hash for external storage key + + +@dataclass +class CacheValue: + """ + Value stored/retrieved from external cache. + + Note: UI data is intentionally excluded - it contains pod-local + file paths that aren't portable across instances. 
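+
+    Example (illustrative only; this API does not mandate a wire format,
+    pickle is just one possible choice a provider might make):
+
+        import pickle
+        payload = pickle.dumps(CacheValue(outputs=[["<tensor>"]]))
+        restored = pickle.loads(payload)
+        assert restored.outputs == [["<tensor>"]]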
+ """ + outputs: list # The tensor/value outputs + + +# ============================================================ +# Provider Interface +# ============================================================ + +class CacheProvider(ABC): + """ + Abstract base class for external cache providers. + + Thread Safety: + Providers may be called from multiple threads. Implementations + must be thread-safe. + + Error Handling: + All methods are wrapped in try/except by the caller. Exceptions + are logged but never propagate to break execution. + + Performance Guidelines: + - on_lookup: Should complete in <500ms (including network) + - on_store: Can be async internally (fire-and-forget) + - should_cache: Should be fast (<1ms), called frequently + """ + + @abstractmethod + def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + """ + Check external storage for cached result. + + Called AFTER local cache miss (local-first for performance). + + Returns: + CacheValue if found externally, None otherwise. + + Important: + - Return None on any error (don't raise) + - Validate data integrity before returning + """ + pass + + @abstractmethod + def on_store(self, context: CacheContext, value: CacheValue) -> None: + """ + Store value to external cache. + + Called AFTER value is stored in local cache. + + Important: + - Can be fire-and-forget (async internally) + - Should never block execution + - Handle serialization failures gracefully + """ + pass + + def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: + """ + Filter which nodes should be externally cached. + + Called before on_lookup (value=None) and on_store (value provided). + Return False to skip external caching for this node. + + Common filters: + - By class_type: Only expensive nodes (KSampler, VAEDecode) + - By size: Skip small values (< 1MB) + + Default: Returns True (cache everything). + """ + return True + + def on_prompt_start(self, prompt_id: str) -> None: + """Called when prompt execution begins. Optional.""" + pass + + def on_prompt_end(self, prompt_id: str) -> None: + """Called when prompt execution ends. Optional.""" + pass + + +# ============================================================ +# Provider Registry +# ============================================================ + +_providers: List[CacheProvider] = [] +_providers_lock = threading.Lock() +_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None + + +def register_cache_provider(provider: CacheProvider) -> None: + """ + Register an external cache provider. + + Providers are called in registration order. First provider to return + a result from on_lookup wins. 
+ """ + global _providers_snapshot + with _providers_lock: + if provider in _providers: + logger.warning(f"Provider {provider.__class__.__name__} already registered") + return + _providers.append(provider) + _providers_snapshot = None # Invalidate cache + logger.info(f"Registered cache provider: {provider.__class__.__name__}") + + +def unregister_cache_provider(provider: CacheProvider) -> None: + """Remove a previously registered provider.""" + global _providers_snapshot + with _providers_lock: + try: + _providers.remove(provider) + _providers_snapshot = None + logger.info(f"Unregistered cache provider: {provider.__class__.__name__}") + except ValueError: + logger.warning(f"Provider {provider.__class__.__name__} was not registered") + + +def get_cache_providers() -> Tuple[CacheProvider, ...]: + """Get registered providers (cached for performance).""" + global _providers_snapshot + snapshot = _providers_snapshot + if snapshot is not None: + return snapshot + with _providers_lock: + if _providers_snapshot is not None: + return _providers_snapshot + _providers_snapshot = tuple(_providers) + return _providers_snapshot + + +def has_cache_providers() -> bool: + """Fast check if any providers registered (no lock).""" + return bool(_providers) + + +def clear_cache_providers() -> None: + """Remove all providers. Useful for testing.""" + global _providers_snapshot + with _providers_lock: + _providers.clear() + _providers_snapshot = None + + +# ============================================================ +# Utilities +# ============================================================ + +def serialize_cache_key(cache_key: Any) -> bytes: + """ + Serialize cache key to bytes for external storage. + + Returns SHA256 hash suitable for Redis/database keys. + """ + try: + serialized = pickle.dumps(cache_key, protocol=4) + return hashlib.sha256(serialized).digest() + except Exception as e: + logger.warning(f"Failed to serialize cache key: {e}") + return hashlib.sha256(str(id(cache_key)).encode()).digest() + + +def contains_nan(obj: Any) -> bool: + """ + Check if cache key contains NaN (indicates uncacheable node). + + NaN != NaN in Python, so local cache never hits. But serialized + NaN would match, causing incorrect external hits. Must skip these. + """ + if isinstance(obj, float): + try: + return math.isnan(obj) + except (TypeError, ValueError): + return False + if hasattr(obj, 'value'): # Unhashable class + val = getattr(obj, 'value', None) + if isinstance(val, float): + try: + return math.isnan(val) + except (TypeError, ValueError): + return False + if isinstance(obj, (frozenset, tuple, list, set)): + return any(contains_nan(item) for item in obj) + if isinstance(obj, dict): + return any(contains_nan(k) or contains_nan(v) for k, v in obj.items()) + return False + + +def estimate_value_size(value: CacheValue) -> int: + """Estimate serialized size in bytes. 
Useful for size-based filtering.""" + try: + import torch + except ImportError: + return 0 + + total = 0 + + def estimate(obj): + nonlocal total + if isinstance(obj, torch.Tensor): + total += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + estimate(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + estimate(item) + + for output in value.outputs: + estimate(output) + return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..e3386d459 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -155,6 +155,10 @@ class BasicCache: self.cache = {} self.subcaches = {} + # External cache provider support + self._is_subcache = False + self._current_prompt_id = '' + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) @@ -201,20 +205,123 @@ class BasicCache: cache_key = self.cache_key_set.get_data_key(node_id) self.cache[cache_key] = value + # Notify external providers + self._notify_providers_store(node_id, cache_key, value) + def _get_immediate(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) + + # Check local cache first (fast path) if cache_key in self.cache: return self.cache[cache_key] - else: + + # Check external providers on local miss + external_result = self._check_providers_lookup(node_id, cache_key) + if external_result is not None: + self.cache[cache_key] = external_result # Warm local cache + return external_result + + return None + + def _notify_providers_store(self, node_id, cache_key, value): + """Notify external providers of cache store.""" + from comfy_execution.cache_provider import ( + has_cache_providers, get_cache_providers, + CacheContext, CacheValue, + serialize_cache_key, contains_nan, logger + ) + + # Fast exit conditions + if self._is_subcache: + return + if not has_cache_providers(): + return + if not self._is_cacheable_value(value): + return + if contains_nan(cache_key): + return + + context = CacheContext( + prompt_id=self._current_prompt_id, + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key=cache_key, + cache_key_bytes=serialize_cache_key(cache_key) + ) + cache_value = CacheValue(outputs=value.outputs) + + for provider in get_cache_providers(): + try: + if provider.should_cache(context, cache_value): + provider.on_store(context, cache_value) + except Exception as e: + logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + + def _check_providers_lookup(self, node_id, cache_key): + """Check external providers for cached result.""" + from comfy_execution.cache_provider import ( + has_cache_providers, get_cache_providers, + CacheContext, CacheValue, + serialize_cache_key, contains_nan, logger + ) + + if self._is_subcache: return None + if not has_cache_providers(): + return None + if contains_nan(cache_key): + return None + + context = CacheContext( + prompt_id=self._current_prompt_id, + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key=cache_key, + cache_key_bytes=serialize_cache_key(cache_key) + ) + + for provider in get_cache_providers(): + try: + if not provider.should_cache(context): + continue + result = provider.on_lookup(context) + if result is not None: + if not isinstance(result, CacheValue): + logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") + continue + if not 
isinstance(result.outputs, (list, tuple)): + logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") + continue + # Import CacheEntry here to avoid circular import at module level + from execution import CacheEntry + return CacheEntry(ui={}, outputs=list(result.outputs)) + except Exception as e: + logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") + + return None + + def _is_cacheable_value(self, value): + """Check if value is a CacheEntry (not objects cache).""" + return hasattr(value, 'outputs') and hasattr(value, 'ui') + + def _get_class_type(self, node_id): + """Get class_type for a node.""" + if not self.initialized or not self.dynprompt: + return '' + try: + return self.dynprompt.get_node(node_id).get('class_type', '') + except Exception: + return '' async def _ensure_subcache(self, node_id, children_ids): subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) + subcache._is_subcache = True # Mark as subcache - excludes from external caching + subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID self.subcaches[subcache_key] = subcache await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) return subcache diff --git a/execution.py b/execution.py index 648f204ec..8abc30d21 100644 --- a/execution.py +++ b/execution.py @@ -669,6 +669,22 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) + def _notify_prompt_lifecycle(self, event: str, prompt_id: str): + """Notify external cache providers of prompt lifecycle events.""" + from comfy_execution.cache_provider import has_cache_providers, get_cache_providers, logger + + if not has_cache_providers(): + return + + for provider in get_cache_providers(): + try: + if event == "start": + provider.on_prompt_start(prompt_id) + elif event == "end": + provider.on_prompt_end(prompt_id) + except Exception as e: + logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -685,66 +701,77 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + # Set prompt ID on caches for external provider integration + for cache in self.caches.all: + cache._current_prompt_id = prompt_id - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) + # Notify external cache providers of prompt start + self._notify_prompt_lifecycle("start", prompt_id) - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, 
self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + try: + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + cached_nodes = [] + for node_id in prompt: + if self.caches.outputs.get(node_id) is not None: + cached_nodes.append(node_id) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + 
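+                        # Success path: finalize the staged node, then let the
+                        # output cache poll enforce the RAM headroom budget.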
+                        execution_list.complete_node_execution()
+                        self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+                else:
+                    # Runs only when the while-loop ends without a break
+                    self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+
+                ui_outputs = {}
+                meta_outputs = {}
+                for node_id, ui_info in ui_node_outputs.items():
+                    ui_outputs[node_id] = ui_info["output"]
+                    meta_outputs[node_id] = ui_info["meta"]
+                self.history_result = {
+                    "outputs": ui_outputs,
+                    "meta": meta_outputs,
+                }
+                self.server.last_node_id = None
+                if comfy.model_management.DISABLE_SMART_MEMORY:
+                    comfy.model_management.unload_all_models()
+        finally:
+            # Notify external cache providers of prompt end
+            self._notify_prompt_lifecycle("end", prompt_id)


 async def validate_inputs(prompt_id, prompt, item, validated):
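
Usage sketch (reviewer note, not part of the patch): a minimal in-process provider that exercises the API above end to end. The class name, the pickle payload format, and the filtering thresholds are illustrative assumptions; a production provider would typically talk to Redis or GCS, add TTLs, and move serialization off the hot path.

    import logging
    import pickle
    from typing import Optional

    from comfy_execution.cache_provider import (
        CacheProvider, CacheContext, CacheValue,
        estimate_value_size, register_cache_provider,
    )

    class InMemoryProvider(CacheProvider):
        """Toy provider: keeps payloads in a dict keyed by the SHA256 cache key."""

        # Hypothetical policy knobs for should_cache(); tune per deployment.
        EXPENSIVE_CLASSES = {"KSampler", "VAEDecode"}
        MIN_SIZE_BYTES = 1 << 20  # skip values under 1 MiB

        def __init__(self):
            self._store = {}  # single-key dict ops are atomic under the GIL

        def should_cache(self, context: CacheContext,
                         value: Optional[CacheValue] = None) -> bool:
            # Cheap filter, called on both the lookup and store paths.
            if context.class_type not in self.EXPENSIVE_CLASSES:
                return False
            # A value is only provided on the store path; apply the size filter there.
            if value is not None and estimate_value_size(value) < self.MIN_SIZE_BYTES:
                return False
            return True

        def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
            payload = self._store.get(context.cache_key_bytes)
            if payload is None:
                return None
            try:
                return CacheValue(outputs=pickle.loads(payload))
            except Exception:
                return None  # never raise from a provider; a miss is always safe

        def on_store(self, context: CacheContext, value: CacheValue) -> None:
            try:
                self._store[context.cache_key_bytes] = pickle.dumps(value.outputs)
            except Exception:
                logging.getLogger(__name__).debug("skipping unpicklable outputs")

    register_cache_provider(InMemoryProvider())

Because the registry wraps every provider call in try/except, any failure in on_lookup or on_store degrades to a local-cache miss instead of a failed prompt; on_store can therefore also hand its payload to a background thread without extra safety plumbing.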