refactor: async CacheProvider API + reduce public surface

- Make on_lookup/on_store async on CacheProvider ABC
- Simplify CacheContext: replace cache_key + cache_key_bytes with
  cache_key_hash (str hex digest)
- Make registry/utility functions internal (_prefix)
- Trim comfy_api.latest.Caching exports to core API only
- Make cache get/set async throughout caching.py hierarchy
- Use asyncio.create_task for fire-and-forget on_store
- Add NaN gating before provider calls in Core
- Add await to 5 cache call sites in execution.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Deep Mehta
2026-03-03 12:34:25 -08:00
parent 0141af0786
commit 4cbe4fe4c7
4 changed files with 107 additions and 100 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import bisect
import gc
import itertools
@@ -200,15 +201,15 @@ class BasicCache:
def poll(self, **kwargs):
    """No-op: accept arbitrary keyword arguments and ignore them."""
    return None
def _set_immediate(self, node_id, value):
async def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
# Notify external providers
self._notify_providers_store(node_id, cache_key, value)
await self._notify_providers_store(node_id, cache_key, value)
def _get_immediate(self, node_id):
async def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
@@ -218,87 +219,88 @@ class BasicCache:
return self.cache[cache_key]
# Check external providers on local miss
external_result = self._check_providers_lookup(node_id, cache_key)
external_result = await self._check_providers_lookup(node_id, cache_key)
if external_result is not None:
self.cache[cache_key] = external_result # Warm local cache
return external_result
return None
async def _notify_providers_store(self, node_id, cache_key, value):
    """Notify external providers of cache store (fire-and-forget).

    Nothing is sent when this cache is a subcache, when no providers are
    registered, when ``value`` is not externally cacheable, when
    ``cache_key`` contains NaN, or when no context can be built.

    For every provider whose ``should_cache(context, cache_value)`` is
    truthy, ``on_store`` is scheduled as a background task through
    ``_safe_provider_store``; this coroutine does not wait for those tasks.
    An exception raised while consulting or scheduling a provider is logged
    and the next provider is tried.
    """
    from comfy_execution.cache_provider import (
        _has_cache_providers, _get_cache_providers,
        CacheValue,
        _contains_nan, _logger
    )

    # Fast exit conditions
    if self._is_subcache:
        return
    if not _has_cache_providers():
        return
    if not self._is_external_cacheable_value(value):
        return
    if _contains_nan(cache_key):
        return

    context = self._build_context(node_id, cache_key)
    if context is None:
        return
    cache_value = CacheValue(outputs=value.outputs, ui=value.ui)

    # The event loop keeps only weak references to tasks, so a task whose
    # result is discarded can be garbage-collected before it finishes.
    # Hold a strong reference until each task completes.
    pending = getattr(self, "_pending_store_tasks", None)
    if pending is None:
        pending = self._pending_store_tasks = set()

    for provider in _get_cache_providers():
        try:
            if provider.should_cache(context, cache_value):
                task = asyncio.create_task(
                    self._safe_provider_store(provider, context, cache_value)
                )
                pending.add(task)
                task.add_done_callback(pending.discard)
        except Exception as e:
            _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
def _check_providers_lookup(self, node_id, cache_key):
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
    """Await ``provider.on_store`` and log, rather than propagate, any failure.

    Intended as the coroutine behind a fire-and-forget task: an ``Exception``
    raised by the provider is reported as a warning and swallowed.
    """
    from comfy_execution.cache_provider import _logger

    try:
        await provider.on_store(context, cache_value)
    except Exception as exc:
        _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {exc}")
async def _check_providers_lookup(self, node_id, cache_key):
    """Check external providers for cached result.

    Returns ``None`` without consulting any provider when this cache is a
    subcache, when no providers are registered, when ``cache_key`` contains
    NaN, or when no context can be built.

    Otherwise providers are tried in the order ``_get_cache_providers()``
    yields them. Providers whose ``should_cache(context)`` is falsy are
    skipped. The first result that is a ``CacheValue`` with list/tuple
    ``outputs`` is converted to a ``CacheEntry`` and returned. Invalid
    results and provider exceptions are logged and the next provider is
    tried. Returns ``None`` if no provider supplies a valid result.
    """
    from comfy_execution.cache_provider import (
        _has_cache_providers, _get_cache_providers,
        CacheValue,
        _contains_nan, _logger
    )

    if self._is_subcache:
        return None
    if not _has_cache_providers():
        return None
    if _contains_nan(cache_key):
        return None

    context = self._build_context(node_id, cache_key)
    if context is None:
        return None

    for provider in _get_cache_providers():
        try:
            if not provider.should_cache(context):
                continue
            result = await provider.on_lookup(context)
            if result is not None:
                if not isinstance(result, CacheValue):
                    _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
                    continue
                if not isinstance(result.outputs, (list, tuple)):
                    _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
                    continue
                # Import CacheEntry here to avoid circular import at module level
                from execution import CacheEntry
                return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
        except Exception as e:
            _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")

    return None
@@ -315,6 +317,16 @@ class BasicCache:
except Exception:
return ''
def _build_context(self, node_id, cache_key):
    """Build CacheContext with hash. Returns None if hashing fails.

    The context carries the current prompt id, ``node_id``, the node's
    class type and the hash that ``_serialize_cache_key`` produces for
    ``cache_key``.
    """
    from comfy_execution.cache_provider import CacheContext, _serialize_cache_key

    # NOTE(review): assumes _serialize_cache_key signals an unhashable key
    # (e.g. NaN) by returning None - confirm against cache_provider.
    # Callers treat a None context as "skip providers", so never hand them
    # a context without a usable hash.
    cache_key_hash = _serialize_cache_key(cache_key)
    if cache_key_hash is None:
        return None

    return CacheContext(
        prompt_id=self._current_prompt_id,
        node_id=node_id,
        class_type=self._get_class_type(node_id),
        cache_key_hash=cache_key_hash
    )
async def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
@@ -364,16 +376,16 @@ class HierarchicalCache(BasicCache):
return None
return cache
async def get(self, node_id):
    """Return the cached value for ``node_id``, or ``None``.

    The cache responsible for ``node_id`` is resolved with
    ``_get_cache_for``. When there is none the result is ``None``; otherwise
    that cache's ``_get_immediate`` is awaited and its result returned.
    """
    cache = self._get_cache_for(node_id)
    if cache is None:
        return None
    return await cache._get_immediate(node_id)
async def set(self, node_id, value):
    """Store ``value`` for ``node_id`` in the cache responsible for it.

    The target cache is resolved with ``_get_cache_for`` and its
    ``_set_immediate`` is awaited.

    Raises:
        AssertionError: if no cache is found for ``node_id``.
    """
    cache = self._get_cache_for(node_id)
    assert cache is not None
    await cache._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -394,10 +406,10 @@ class NullCache:
def poll(self, **kwargs):
    """No-op: accept arbitrary keyword arguments and ignore them."""
    return None
async def get(self, node_id):
    """Always report a miss: return ``None`` for every ``node_id``."""
    return None
async def set(self, node_id, value):
    """Discard ``value``: nothing is stored for ``node_id``."""
    pass
async def ensure_subcache_for(self, node_id, children_ids):
@@ -429,18 +441,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
async def get(self, node_id):
    """Mark ``node_id`` as used, then return its cached value.

    The usage mark is recorded before the lookup. The result is whatever
    awaiting ``_get_immediate`` yields for ``node_id``.
    """
    self._mark_used(node_id)
    return await self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
async def set(self, node_id, value):
    """Mark ``node_id`` as used, then store ``value`` for it.

    The usage mark is recorded before the store. Returns whatever awaiting
    ``_set_immediate`` yields.
    """
    self._mark_used(node_id)
    return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -480,13 +492,13 @@ class RAMPressureCache(LRUCache):
def clean_unused(self):
self._clean_subcaches()
async def set(self, node_id, value):
    """Timestamp the entry for ``node_id``, then store ``value`` via the parent.

    The current ``time.time()`` is recorded in ``self.timestamps`` under the
    node's data key before the parent class's ``set`` is awaited.
    """
    self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
    await super().set(node_id, value)
async def get(self, node_id):
    """Timestamp the entry for ``node_id``, then return the parent's lookup.

    The current ``time.time()`` is recorded in ``self.timestamps`` under the
    node's data key before the parent class's ``get`` is awaited. The
    timestamp is written whether or not the lookup hits.
    """
    self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
    return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():