mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-08 06:39:56 +00:00
Strip verbose docstrings and section banners to match existing minimal documentation style used throughout the codebase. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
199 lines
6.4 KiB
Python
199 lines
6.4 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Optional, Tuple, List
|
|
from dataclasses import dataclass
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import math
|
|
import pickle
|
|
import threading
|
|
|
|
# Module-level logger used by the provider registry and key-serialization functions in this file.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class CacheContext:
    """Context object handed to a cache provider for one lookup or store call."""

    prompt_id: str
    node_id: str
    class_type: str
    # SHA256 hex digest
    cache_key_hash: str
|
|
|
|
|
|
@dataclass
class CacheValue:
    """Payload exchanged with a cache provider: returned by lookups, passed to stores."""

    outputs: list
    # May be omitted; the annotation is Optional because the default is None.
    ui: Optional[dict] = None
|
|
|
|
|
|
class CacheProvider(ABC):
    """Interface that every external cache provider implements.

    The caller catches exceptions raised by provider methods, so a failing
    provider never breaks execution.
    """

    @abstractmethod
    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """Invoked on a local cache miss; return a CacheValue on a hit, else None."""

    @abstractmethod
    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """Invoked after the local store; dispatched via asyncio.create_task."""

    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
        """Return False to opt this node out of external caching (default: True)."""
        return True

    def on_prompt_start(self, prompt_id: str) -> None:
        """Optional hook; the default does nothing.

        Presumably invoked when a prompt starts -- the caller is not part of this file.
        """

    def on_prompt_end(self, prompt_id: str) -> None:
        """Optional hook; the default does nothing.

        Presumably invoked when a prompt ends -- the caller is not part of this file.
        """
|
|
|
|
|
|
# Held by register/unregister/clear while they mutate the list, and by
# _get_cache_providers while it rebuilds the snapshot.
_providers_lock = threading.Lock()
# Registered providers, in registration order.
_providers: List[CacheProvider] = []
# Cached tuple copy of _providers; None means it has to be rebuilt.
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
|
|
|
|
|
|
def register_cache_provider(provider: CacheProvider) -> None:
    """Add *provider* to the registry; providers are called in registration order.

    Registering a provider that is already present only logs a warning.
    """
    global _providers_snapshot
    with _providers_lock:
        if provider not in _providers:
            _providers.append(provider)
            # Invalidate the cached tuple so the next reader rebuilds it.
            _providers_snapshot = None
            _logger.info(f"Registered cache provider: {provider.__class__.__name__}")
        else:
            _logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
|
|
|
|
|
def unregister_cache_provider(provider: CacheProvider) -> None:
    """Remove *provider* from the registry; log a warning if it was never registered."""
    global _providers_snapshot
    with _providers_lock:
        try:
            _providers.remove(provider)
        except ValueError:
            # list.remove raises ValueError when the item is absent.
            _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
        else:
            # Invalidate the cached tuple so the next reader rebuilds it.
            _providers_snapshot = None
            _logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
|
|
|
|
|
|
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
    """Return the registered providers as a tuple, rebuilding the cached copy if needed."""
    global _providers_snapshot
    current = _providers_snapshot
    if current is None:
        with _providers_lock:
            # Re-check under the lock: another thread may have rebuilt it meanwhile.
            if _providers_snapshot is None:
                _providers_snapshot = tuple(_providers)
            current = _providers_snapshot
    return current
|
|
|
|
|
|
def _has_cache_providers() -> bool:
    """Return True when at least one provider is registered (read without taking the lock)."""
    return len(_providers) > 0
|
|
|
|
|
|
def _clear_cache_providers() -> None:
    """Empty the provider registry and drop the cached snapshot."""
    global _providers_snapshot
    with _providers_lock:
        # In-place removal keeps the same list object.
        del _providers[:]
        _providers_snapshot = None
|
|
|
|
|
|
def _canonicalize(obj: Any) -> Any:
    """Convert *obj* into a JSON-serializable form with deterministic ordering.

    Containers are converted recursively. Values JSON cannot express directly
    are wrapped in a ``(tag, payload)`` pair:

    * frozenset / set -> ``("__frozenset__" | "__set__", sorted members)``
    * tuple           -> ``("__tuple__", [members])``
    * bytes           -> ``("__bytes__", hex string)``
    * object exposing ``value`` -> ``("__unhashable__", canonical value)``
    * anything else   -> ``("__repr__", repr(obj))``

    Dict keys are stringified with ``str``; ints, floats, strings, bools and
    None are returned unchanged.
    """
    if isinstance(obj, (frozenset, set)):
        # Set iteration order is not stable between Python sessions, so the
        # canonicalized members are ordered by their JSON text instead.
        tag = "__frozenset__" if isinstance(obj, frozenset) else "__set__"
        return (tag, sorted(
            (_canonicalize(item) for item in obj),
            key=lambda x: json.dumps(x, sort_keys=True),
        ))
    if isinstance(obj, tuple):
        return ("__tuple__", [_canonicalize(item) for item in obj])
    if isinstance(obj, list):
        return [_canonicalize(item) for item in obj]
    if isinstance(obj, dict):
        # Order by the stringified key (the same text used as the output key)
        # instead of comparing raw items: comparing raw keys raises TypeError
        # for mixed or otherwise unorderable key types. The resulting mapping
        # has the same contents as before for dicts that sorted successfully.
        return {
            str(k): _canonicalize(v)
            for k, v in sorted(obj.items(), key=lambda kv: str(kv[0]))
        }
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    if isinstance(obj, bytes):
        return ("__bytes__", obj.hex())
    if hasattr(obj, 'value'):
        # Handle Unhashable class from ComfyUI
        return ("__unhashable__", _canonicalize(getattr(obj, 'value', None)))
    # For other types, use repr as fallback
    return ("__repr__", repr(obj))
|
|
|
|
|
|
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
    """Return a SHA256 hex digest for *cache_key*, or None when it cannot be serialized.

    Canonical JSON is tried first because pickle is non-deterministic across
    sessions; pickle is only a last-resort fallback.
    """
    # Stage 1: deterministic canonical JSON.
    try:
        payload = json.dumps(_canonicalize(cache_key), sort_keys=True, separators=(',', ':'))
        return hashlib.sha256(payload.encode('utf-8')).hexdigest()
    except Exception as e:
        _logger.warning(f"Failed to serialize cache key: {e}")

    # Stage 2: pickle (non-deterministic but better than nothing).
    try:
        return hashlib.sha256(pickle.dumps(cache_key, protocol=4)).hexdigest()
    except Exception as fallback_error:
        _logger.warning(f"Failed pickle fallback for cache key: {fallback_error}")
        return None
|
|
|
|
|
|
def _contains_nan(obj: Any) -> bool:
    """Return True if *obj* is, wraps, or contains a float NaN.

    NaN != NaN, so the local cache never hits on such keys, but a serialized
    NaN would match; keys containing NaN are therefore skipped for external
    caching.

    The search covers floats, the ``value`` attribute of wrapper objects
    (e.g. the Unhashable class), frozenset/tuple/list/set members, and dict
    keys and values.
    """
    if isinstance(obj, float):
        return math.isnan(obj)
    if hasattr(obj, 'value'):  # Unhashable class
        val = getattr(obj, 'value', None)
        # Recurse so NaN nested inside a wrapped container is found too,
        # mirroring how _canonicalize descends into ``value``. The identity
        # check guards against an object whose ``value`` is itself.
        if val is not obj and _contains_nan(val):
            return True
    if isinstance(obj, (frozenset, tuple, list, set)):
        return any(_contains_nan(item) for item in obj)
    if isinstance(obj, dict):
        return any(_contains_nan(k) or _contains_nan(v) for k, v in obj.items())
    return False
|
|
|
|
|
|
def _estimate_value_size(value: CacheValue) -> int:
    """Return the total byte size of all torch tensors found in ``value.outputs``.

    Dicts (values only), lists and tuples are searched recursively; every other
    object counts as zero. Returns 0 when torch cannot be imported.
    """
    try:
        import torch
    except ImportError:
        return 0

    def tensor_bytes(node) -> int:
        # Bytes held by tensors reachable from *node*.
        if isinstance(node, torch.Tensor):
            return node.numel() * node.element_size()
        if isinstance(node, dict):
            return sum(tensor_bytes(child) for child in node.values())
        if isinstance(node, (list, tuple)):
            return sum(tensor_bytes(child) for child in node)
        return 0

    return sum(tensor_bytes(output) for output in value.outputs)
|