Revert "Revert "feat: Add CacheProvider API for external distributed caching …"

This reverts commit d1d53c14be.
This commit is contained in:
Deep Mehta
2026-03-12 19:56:21 -07:00
committed by GitHub
parent 63d1bbdb40
commit e6bb7f03dd
7 changed files with 859 additions and 83 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import bisect
import gc
import itertools
@@ -154,6 +155,7 @@ class BasicCache:
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@@ -196,18 +198,134 @@ class BasicCache:
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
def get_local(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
else:
return None
def set_local(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
async def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
await self._notify_providers_store(node_id, cache_key, value)
async def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
external_result = await self._check_providers_lookup(node_id, cache_key)
if external_result is not None:
self.cache[cache_key] = external_result
return external_result
return None
async def _notify_providers_store(self, node_id, cache_key, value):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not _has_cache_providers():
return
if not self._is_external_cacheable_value(value):
return
if _contains_self_unequal(cache_key):
return
context = self._build_context(node_id, cache_key)
if context is None:
return
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
for provider in _get_cache_providers():
try:
if provider.should_cache(context, cache_value):
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
self._pending_store_tasks.add(task)
task.add_done_callback(self._pending_store_tasks.discard)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
from comfy_execution.cache_provider import _logger
try:
await provider.on_store(context, cache_value)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
async def _check_providers_lookup(self, node_id, cache_key):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not _has_cache_providers():
return None
if _contains_self_unequal(cache_key):
return None
context = self._build_context(node_id, cache_key)
if context is None:
return None
for provider in _get_cache_providers():
try:
if not provider.should_cache(context):
continue
result = await provider.on_lookup(context)
if result is not None:
if not isinstance(result, CacheValue):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
continue
if not isinstance(result.outputs, (list, tuple)):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
continue
from execution import CacheEntry
return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
return None
def _is_external_cacheable_value(self, value):
return hasattr(value, 'outputs') and hasattr(value, 'ui')
def _get_class_type(self, node_id):
if not self.initialized or not self.dynprompt:
return ''
try:
return self.dynprompt.get_node(node_id).get('class_type', '')
except Exception:
return ''
def _build_context(self, node_id, cache_key):
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
try:
cache_key_hash = _serialize_cache_key(cache_key)
if cache_key_hash is None:
return None
return CacheContext(
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key_hash=cache_key_hash,
)
except Exception as e:
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
return None
async def _ensure_subcache(self, node_id, children_ids):
@@ -257,16 +375,27 @@ class HierarchicalCache(BasicCache):
return None
return cache
def get(self, node_id):
async def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return cache._get_immediate(node_id)
return await cache._get_immediate(node_id)
def set(self, node_id, value):
def get_local(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return BasicCache.get_local(cache, node_id)
async def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
cache._set_immediate(node_id, value)
await cache._set_immediate(node_id, value)
def set_local(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
BasicCache.set_local(cache, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -287,10 +416,16 @@ class NullCache:
def poll(self, **kwargs):
pass
def get(self, node_id):
async def get(self, node_id):
return None
def set(self, node_id, value):
def get_local(self, node_id):
return None
async def set(self, node_id, value):
pass
def set_local(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
@@ -322,18 +457,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
def get(self, node_id):
async def get(self, node_id):
self._mark_used(node_id)
return self._get_immediate(node_id)
return await self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
def set(self, node_id, value):
async def set(self, node_id, value):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -373,13 +508,13 @@ class RAMPressureCache(LRUCache):
def clean_unused(self):
self._clean_subcaches()
def set(self, node_id, value):
async def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
await super().set(node_id, value)
def get(self, node_id):
async def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():