refactor: async CacheProvider API + reduce public surface

- Make on_lookup/on_store async on CacheProvider ABC
- Simplify CacheContext: replace cache_key + cache_key_bytes with
  cache_key_hash (str hex digest)
- Make registry/utility functions internal (_prefix)
- Trim comfy_api.latest.Caching exports to core API only
- Make cache get/set async throughout caching.py hierarchy
- Use asyncio.create_task for fire-and-forget on_store
- Add NaN gating before provider calls in Core
- Add await to 5 cache call sites in execution.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Deep Mehta
2026-03-03 12:34:25 -08:00
parent 0141af0786
commit 4cbe4fe4c7
4 changed files with 107 additions and 100 deletions

View File

@@ -13,12 +13,12 @@ Example usage:
)
class MyRedisProvider(CacheProvider):
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
# Check Redis/GCS for cached result
...
def on_store(self, context: CacheContext, value: CacheValue) -> None:
# Store to Redis/GCS (can be async internally)
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
# Store to Redis/GCS
...
register_cache_provider(MyRedisProvider())
@@ -34,7 +34,7 @@ import math
import pickle
import threading
logger = logging.getLogger(__name__)
_logger = logging.getLogger(__name__)
# ============================================================
@@ -47,8 +47,7 @@ class CacheContext:
prompt_id: str # Current prompt execution ID
node_id: str # Node being cached
class_type: str # Node class type (e.g., "KSampler")
cache_key: Any # Raw cache key (frozenset structure)
cache_key_bytes: bytes # SHA256 hash for external storage key
cache_key_hash: str # SHA256 hex digest for external storage key
@dataclass
@@ -71,9 +70,9 @@ class CacheProvider(ABC):
"""
Abstract base class for external cache providers.
Thread Safety:
Providers may be called from multiple threads. Implementations
must be thread-safe.
Async Safety:
Provider methods are called from async context. Implementations
can use async I/O (aiohttp, asyncpg, etc.) directly.
Error Handling:
All methods are wrapped in try/except by the caller. Exceptions
@@ -81,12 +80,12 @@ class CacheProvider(ABC):
Performance Guidelines:
- on_lookup: Should complete in <500ms (including network)
- on_store: Can be async internally (fire-and-forget)
- on_store: Fire-and-forget via asyncio.create_task
- should_cache: Should be fast (<1ms), called frequently
"""
@abstractmethod
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""
Check external storage for cached result.
@@ -102,14 +101,14 @@ class CacheProvider(ABC):
pass
@abstractmethod
def on_store(self, context: CacheContext, value: CacheValue) -> None:
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""
Store value to external cache.
Called AFTER value is stored in local cache.
Dispatched as asyncio.create_task (fire-and-forget).
Important:
- Can be fire-and-forget (async internally)
- Should never block execution
- Handle serialization failures gracefully
"""
@@ -123,7 +122,7 @@ class CacheProvider(ABC):
Return False to skip external caching for this node.
Implementations can filter based on context.class_type, value size,
or any custom logic. Use estimate_value_size() to get value size.
or any custom logic. Use _estimate_value_size() to get value size.
Default: Returns True (cache everything).
"""
@@ -157,11 +156,11 @@ def register_cache_provider(provider: CacheProvider) -> None:
global _providers_snapshot
with _providers_lock:
if provider in _providers:
logger.warning(f"Provider {provider.__class__.__name__} already registered")
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = None # Invalidate cache
logger.info(f"Registered cache provider: {provider.__class__.__name__}")
_logger.info(f"Registered cache provider: {provider.__class__.__name__}")
def unregister_cache_provider(provider: CacheProvider) -> None:
@@ -171,13 +170,13 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
try:
_providers.remove(provider)
_providers_snapshot = None
logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
_logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
logger.warning(f"Provider {provider.__class__.__name__} was not registered")
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def get_cache_providers() -> Tuple[CacheProvider, ...]:
"""Get registered providers (cached for performance)."""
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
"""Get registered providers (cached for performance). Internal."""
global _providers_snapshot
snapshot = _providers_snapshot
if snapshot is not None:
@@ -189,13 +188,13 @@ def get_cache_providers() -> Tuple[CacheProvider, ...]:
return _providers_snapshot
def has_cache_providers() -> bool:
"""Fast check if any providers registered (no lock)."""
def _has_cache_providers() -> bool:
"""Fast check if any providers registered (no lock). Internal."""
return bool(_providers)
def clear_cache_providers() -> None:
"""Remove all providers. Useful for testing."""
def _clear_cache_providers() -> None:
"""Remove all providers. Useful for testing. Internal."""
global _providers_snapshot
with _providers_lock:
_providers.clear()
@@ -203,7 +202,7 @@ def clear_cache_providers() -> None:
# ============================================================
# Utilities
# Internal Utilities
# ============================================================
def _canonicalize(obj: Any) -> Any:
@@ -243,11 +242,11 @@ def _canonicalize(obj: Any) -> Any:
return ("__repr__", repr(obj))
def serialize_cache_key(cache_key: Any) -> bytes:
def _serialize_cache_key(cache_key: Any) -> str:
"""
Serialize cache key to bytes for external storage.
Serialize cache key to a hex digest string for external storage.
Returns SHA256 hash suitable for Redis/database keys.
Returns SHA256 hex string suitable for Redis/database keys.
Note: Uses canonicalize + JSON serialization instead of pickle because
pickle is NOT deterministic across Python sessions due to hash randomization
@@ -257,18 +256,18 @@ def serialize_cache_key(cache_key: Any) -> bytes:
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(json_str.encode('utf-8')).digest()
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
except Exception as e:
logger.warning(f"Failed to serialize cache key: {e}")
_logger.warning(f"Failed to serialize cache key: {e}")
# Fallback to pickle (non-deterministic but better than nothing)
try:
serialized = pickle.dumps(cache_key, protocol=4)
return hashlib.sha256(serialized).digest()
return hashlib.sha256(serialized).hexdigest()
except Exception:
return hashlib.sha256(str(id(cache_key)).encode()).digest()
return hashlib.sha256(str(id(cache_key)).encode()).hexdigest()
def contains_nan(obj: Any) -> bool:
def _contains_nan(obj: Any) -> bool:
"""
Check if cache key contains NaN (indicates uncacheable node).
@@ -288,14 +287,14 @@ def contains_nan(obj: Any) -> bool:
except (TypeError, ValueError):
return False
if isinstance(obj, (frozenset, tuple, list, set)):
return any(contains_nan(item) for item in obj)
return any(_contains_nan(item) for item in obj)
if isinstance(obj, dict):
return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
return any(_contains_nan(k) or _contains_nan(v) for k, v in obj.items())
return False
def estimate_value_size(value: CacheValue) -> int:
"""Estimate serialized size in bytes. Useful for size-based filtering."""
def _estimate_value_size(value: CacheValue) -> int:
"""Estimate serialized size in bytes. Useful for size-based filtering. Internal."""
try:
import torch
except ImportError: