mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 09:10:12 +00:00
* feat: Add CacheProvider API for external distributed caching

  Introduces a public API for external cache providers, enabling distributed
  caching across multiple ComfyUI instances (e.g., Kubernetes pods).

  New files:
  - comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue
    dataclasses, thread-safe provider registry, serialization utilities

  Modified files:
  - comfy_execution/caching.py: Add provider hooks to BasicCache
    (_notify_providers_store, _check_providers_lookup), subcache exclusion,
    prompt ID propagation
  - execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end)
    to PromptExecutor, set _current_prompt_id on caches

  Key features:
  - Local-first caching (check local before external for performance)
  - NaN detection to prevent incorrect external cache hits
  - Subcache exclusion (ephemeral subgraph results not cached externally)
  - Thread-safe provider snapshot caching
  - Graceful error handling (provider errors logged, never break execution)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <noreply@anthropic.com>

* fix: use deterministic hash for cache keys instead of pickle

  Pickle serialization is NOT deterministic across Python sessions because
  hash randomization affects frozenset iteration order. This causes
  distributed caching to fail: different pods compute different hashes for
  identical cache keys. Fix: use _canonicalize() + JSON serialization, which
  ensures deterministic ordering regardless of Python's hash randomization.
  This is critical for cross-pod cache key consistency in Kubernetes
  deployments.

* test: add unit tests for CacheProvider API

  - Add comprehensive tests for _canonicalize deterministic ordering
  - Add tests for serialize_cache_key hash consistency
  - Add tests for contains_nan utility
  - Add tests for estimate_value_size
  - Add tests for provider registry (register, unregister, clear)
  - Move json import to top level (fix inline import)

* style: remove unused imports in test_cache_provider.py

* fix: move _torch_available before usage and use importlib.util.find_spec

  Fixes ruff F821 (undefined name) and F401 (unused import) errors.

* fix: use hashable types in frozenset test and add dict test

  Frozensets can only contain hashable types, so use nested frozensets
  instead of dicts. Added a separate test for dict handling via
  serialize_cache_key.

* refactor: expose CacheProvider API via comfy_api.latest.Caching

  - Add Caching class to comfy_api/latest/__init__.py that re-exports from
    comfy_execution.cache_provider (source of truth)
  - Fix docstring: "Skip large values" instead of "Skip small values"
    (small compute-heavy values are good cache targets)
  - Maintain backward compatibility: comfy_execution.cache_provider imports
    still work

  Usage:

      from comfy_api.latest import Caching

      class MyProvider(Caching.CacheProvider):
          def on_lookup(self, context): ...
          def on_store(self, context, value): ...

      Caching.register_provider(MyProvider())

* docs: clarify should_cache filtering criteria

  Change docstring from "Skip large values" to "Skip if download time >
  compute time", which better captures the cost/benefit tradeoff for
  external caching.

* docs: make should_cache docstring implementation-agnostic

  Remove prescriptive filtering suggestions: let implementations decide
  their own caching logic based on their use case.

* feat: add optional ui field to CacheValue

  - Add ui field to CacheValue dataclass (default None)
  - Pass ui when creating CacheValue for external providers
  - Use result.ui (or default {}) when returning from external cache lookup

  This allows external cache implementations to store/retrieve UI data if
  desired, while remaining optional for implementations that skip it.

* refactor: rename _is_cacheable_value to _is_external_cacheable_value

  Clearer name, since objects are also cached locally; this specifically
  checks for external caching eligibility.

* refactor: async CacheProvider API + reduce public surface

  - Make on_lookup/on_store async on the CacheProvider ABC
  - Simplify CacheContext: replace cache_key + cache_key_bytes with
    cache_key_hash (str hex digest)
  - Make registry/utility functions internal (_ prefix)
  - Trim comfy_api.latest.Caching exports to core API only
  - Make cache get/set async throughout the caching.py hierarchy
  - Use asyncio.create_task for fire-and-forget on_store
  - Add NaN gating before provider calls in core
  - Add await to 5 cache call sites in execution.py

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: remove unused imports (ruff) and update tests for internal API

  - Remove unused CacheContext and _serialize_cache_key imports from
    caching.py (now handled by the _build_context helper)
  - Update test_cache_provider.py to use _-prefixed internal names
  - Update tests for the new CacheContext.cache_key_hash field (str)
  - Make MockCacheProvider methods async to match the ABC

* fix: address coderabbit review feedback

  - Add try/except to _build_context, return None when hashing fails
  - Return None from _serialize_cache_key on total failure (no id()-based
    fallback)
  - Replace hex-like test literal with a non-secret placeholder

* fix: use _-prefixed imports in _notify_prompt_lifecycle

  The lifecycle notification method was importing the old non-prefixed names
  (has_cache_providers, get_cache_providers, logger), which no longer exist
  after the API cleanup.

* fix: add sync get_local/set_local for graph traversal

  ExecutionList in graph.py calls output_cache.get() and .set() from sync
  methods (is_cached, cache_link, get_cache). These cannot await the
  now-async get/set. Add get_local/set_local that bypass external providers
  and only access the local dict, which is all graph traversal needs.

* chore: remove cloud-specific language from cache provider API

  Make all docstrings and comments generic for the OSS codebase. Remove
  references to Kubernetes, Redis, GCS, pods, and other
  infrastructure-specific terminology.

* style: align documentation with codebase conventions

  Strip verbose docstrings and section banners to match the existing minimal
  documentation style used throughout the codebase.

* fix: add usage example to Caching class, remove pickle fallback

  - Add a docstring with a usage example to the Caching class, matching the
    convention used by sibling APIs (Execution.set_progress, ComfyExtension)
  - Remove the non-deterministic pickle fallback from _serialize_cache_key;
    return None on JSON failure instead of producing unretrievable hashes
  - Move cache_provider imports to the top of execution.py (no circular dep)

* refactor: move public types to comfy_api, eager provider snapshot

  Address review feedback:
  - Move CacheProvider/CacheContext/CacheValue definitions to
    comfy_api/latest/_caching.py (source of truth for the public API)
  - comfy_execution/cache_provider.py re-exports types from there
  - Build _providers_snapshot eagerly on register/unregister instead of lazy
    memoization in _get_cache_providers

* fix: generalize self-inequality check, fail-closed canonicalization

  Address review feedback from guill:
  - Rename _contains_nan to _contains_self_unequal; use not (x == x) instead
    of math.isnan to catch any self-unequal value
  - Remove Unhashable and repr() fallbacks from _canonicalize; raise
    ValueError for unknown types so _serialize_cache_key returns None and
    external caching is skipped (fail-closed)
  - Update tests for the renamed function and new fail-closed behavior

* fix: suppress ruff F401 for re-exported CacheContext

  CacheContext is imported from _caching and re-exported for use by
  caching.py. Add a noqa comment to satisfy the linter.

* fix: enable external caching for subcache (expanded) nodes

  Subcache nodes (from node expansion) now participate in external provider
  store/lookup. They were previously skipped to avoid duplicates, but the
  cost of missing partial-expansion cache hits outweighs redundant stores,
  especially with looping behavior on the horizon.

* fix: wrap register/unregister as explicit static methods

  Define register_provider and unregister_provider as wrapper functions in
  the Caching class instead of re-importing. This locks the public API
  signature in comfy_api/ so internal changes can't accidentally break it.

* fix: use debug-level logging for provider registration

* fix: follow ProxiedSingleton pattern for Caching class

  Add Caching as a nested class inside ComfyAPI_latest inheriting from
  ProxiedSingleton with async instance methods, matching the Execution and
  NodeReplacement patterns. Retains the standalone Caching class for direct
  import convenience.

* fix: inline registration logic in Caching class

  Follow the Execution/NodeReplacement pattern: the public API methods
  contain the actual logic operating on cache_provider module state, not
  wrapper functions delegating to free functions.

* fix: single Caching definition inside ComfyAPI_latest

  Remove the duplicate standalone Caching class. Define it once as a nested
  class in ComfyAPI_latest (matching the Execution/NodeReplacement pattern),
  with a module-level alias for import convenience.

* fix: remove prompt_id from CacheContext, type-safe canonicalization

  Remove prompt_id from CacheContext; it is not relevant for cache matching
  and added unnecessary plumbing (_current_prompt_id on every cache).
  Lifecycle hooks still receive prompt_id directly. Include the type name in
  canonicalized primitives so that int 7 and str "7" produce distinct
  hashes. Also canonicalize dict keys properly instead of str() coercion.

* fix: address review feedback on cache provider API

  - Hold references to pending store tasks to prevent "Task was destroyed
    but it is still pending" warnings (bigcat88)
  - Parallel cache lookups with asyncio.gather instead of sequential awaits
    for better performance (bigcat88)
  - Delegate Caching.register/unregister_provider to existing functions in
    cache_provider.py instead of reimplementing (bigcat88)

---------

Co-authored-by: Claude <noreply@anthropic.com>
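The provider contract described above (async `on_lookup`/`on_store`, a sync `should_cache` gate, a `CacheContext` carrying `node_id`, `class_type`, and `cache_key_hash`, and a `CacheValue` with `outputs` plus an optional `ui` field) can be sketched with a standalone in-memory backend. This is an illustrative stand-in, not the real API surface: the dataclasses and the `DictCacheProvider` class below are hypothetical mirrors of the shapes named in the commit messages, and a real provider would subclass `Caching.CacheProvider` from `comfy_api.latest` and talk to a distributed store.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class CacheContext:
    # Mirrors the fields named in the commits (prompt_id was removed).
    node_id: str
    class_type: str
    cache_key_hash: str  # hex digest of the canonicalized cache key

@dataclass
class CacheValue:
    outputs: list
    ui: Optional[dict] = None  # optional UI data (added in a later commit)

class DictCacheProvider:
    """Hypothetical stand-in for a distributed backend (e.g. Redis)."""

    def __init__(self):
        self.store = {}

    def should_cache(self, context, value=None):
        # Implementations decide their own cost/benefit policy here.
        return True

    async def on_lookup(self, context):
        # Return a CacheValue on hit, None on miss.
        return self.store.get(context.cache_key_hash)

    async def on_store(self, context, value):
        self.store[context.cache_key_hash] = value

async def demo():
    provider = DictCacheProvider()
    ctx = CacheContext(node_id="5", class_type="KSampler", cache_key_hash="ab12")
    await provider.on_store(ctx, CacheValue(outputs=[[1, 2]]))
    hit = await provider.on_lookup(ctx)
    return hit.outputs

# asyncio.run(demo()) returns [[1, 2]]
```

Because stores are fired with `asyncio.create_task` and errors are swallowed with a warning, a slow or failing backend never blocks workflow execution; lookups are the only awaited path.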
559 lines
20 KiB
Python
import asyncio
import bisect
import gc
import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod

import nodes

from comfy_execution.graph_utils import is_link

NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}


def include_unique_id_in_input(class_type: str) -> bool:
    if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
        return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
    return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]

class CacheKeySet(ABC):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        self.keys = {}
        self.subcache_keys = {}

    @abstractmethod
    async def add_keys(self, node_ids):
        raise NotImplementedError()

    def all_node_ids(self):
        return set(self.keys.keys())

    def get_used_keys(self):
        return self.keys.values()

    def get_used_subcache_keys(self):
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        return self.keys.get(node_id, None)

    def get_subcache_key(self, node_id):
        return self.subcache_keys.get(node_id, None)


class Unhashable:
    def __init__(self):
        self.value = float("NaN")


def to_hashable(obj):
    # So that we don't infinitely recurse since frozenset and tuples
    # are Sequences.
    if isinstance(obj, (int, float, str, bool, bytes, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
    else:
        # TODO - Support other objects like tensors?
        return Unhashable()

class CacheKeySetID(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt

    async def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            if not self.dynprompt.has_node(node_id):
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = (node_id, node["class_type"])
            self.subcache_keys[node_id] = (node_id, node["class_type"])

class CacheKeySetInputSignature(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache

    def include_node_id_in_input(self) -> bool:
        return False

    async def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            if not self.dynprompt.has_node(node_id):
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
            self.subcache_keys[node_id] = (node_id, node["class_type"])

    async def get_node_signature(self, dynprompt, node_id):
        signature = []
        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
        signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
        for ancestor_id in ancestors:
            signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
        return to_hashable(signature)

    async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
        if not dynprompt.has_node(node_id):
            # This node doesn't exist -- we can't cache it.
            return [float("NaN")]
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, await self.is_changed_cache.get(node_id)]
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
            if is_link(inputs[key]):
                (ancestor_id, ancestor_socket) = inputs[key]
                ancestor_index = ancestor_order_mapping[ancestor_id]
                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
        return signature

    # This function returns a list of all ancestors of the given node. The order of the list is
    # deterministic based on which specific inputs the ancestor is connected by.
    def get_ordered_ancestry(self, dynprompt, node_id):
        ancestors = []
        order_mapping = {}
        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
        return ancestors, order_mapping

    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
        if not dynprompt.has_node(node_id):
            return
        inputs = dynprompt.get_node(node_id)["inputs"]
        input_keys = sorted(inputs.keys())
        for key in input_keys:
            if is_link(inputs[key]):
                ancestor_id = inputs[key][0]
                if ancestor_id not in order_mapping:
                    ancestors.append(ancestor_id)
                    order_mapping[ancestor_id] = len(ancestors) - 1
                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

class BasicCache:
    def __init__(self, key_class):
        self.key_class = key_class
        self.initialized = False
        self.dynprompt: DynamicPrompt
        self.cache_key_set: CacheKeySet
        self.cache = {}
        self.subcaches = {}
        self._pending_store_tasks: set = set()

    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
        await self.cache_key_set.add_keys(node_ids)
        self.is_changed_cache = is_changed_cache
        self.initialized = True

    def all_node_ids(self):
        assert self.initialized
        node_ids = self.cache_key_set.all_node_ids()
        for subcache in self.subcaches.values():
            node_ids = node_ids.union(subcache.all_node_ids())
        return node_ids

    def _clean_cache(self):
        preserve_keys = set(self.cache_key_set.get_used_keys())
        to_remove = []
        for key in self.cache:
            if key not in preserve_keys:
                to_remove.append(key)
        for key in to_remove:
            del self.cache[key]

    def _clean_subcaches(self):
        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())

        to_remove = []
        for key in self.subcaches:
            if key not in preserve_subcaches:
                to_remove.append(key)
        for key in to_remove:
            del self.subcaches[key]

    def clean_unused(self):
        assert self.initialized
        self._clean_cache()
        self._clean_subcaches()

    def poll(self, **kwargs):
        pass

    def get_local(self, node_id):
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key in self.cache:
            return self.cache[cache_key]
        return None

    def set_local(self, node_id, value):
        assert self.initialized
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

    async def _set_immediate(self, node_id, value):
        assert self.initialized
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

        await self._notify_providers_store(node_id, cache_key, value)

    async def _get_immediate(self, node_id):
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)

        if cache_key in self.cache:
            return self.cache[cache_key]

        external_result = await self._check_providers_lookup(node_id, cache_key)
        if external_result is not None:
            self.cache[cache_key] = external_result
            return external_result

        return None

    async def _notify_providers_store(self, node_id, cache_key, value):
        from comfy_execution.cache_provider import (
            _has_cache_providers, _get_cache_providers,
            CacheValue, _contains_self_unequal, _logger
        )

        if not _has_cache_providers():
            return
        if not self._is_external_cacheable_value(value):
            return
        if _contains_self_unequal(cache_key):
            return

        context = self._build_context(node_id, cache_key)
        if context is None:
            return
        cache_value = CacheValue(outputs=value.outputs, ui=value.ui)

        for provider in _get_cache_providers():
            try:
                if provider.should_cache(context, cache_value):
                    task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
                    self._pending_store_tasks.add(task)
                    task.add_done_callback(self._pending_store_tasks.discard)
            except Exception as e:
                _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")

    @staticmethod
    async def _safe_provider_store(provider, context, cache_value):
        from comfy_execution.cache_provider import _logger
        try:
            await provider.on_store(context, cache_value)
        except Exception as e:
            _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")

    async def _check_providers_lookup(self, node_id, cache_key):
        from comfy_execution.cache_provider import (
            _has_cache_providers, _get_cache_providers,
            CacheValue, _contains_self_unequal, _logger
        )

        if not _has_cache_providers():
            return None
        if _contains_self_unequal(cache_key):
            return None

        context = self._build_context(node_id, cache_key)
        if context is None:
            return None

        for provider in _get_cache_providers():
            try:
                if not provider.should_cache(context):
                    continue
                result = await provider.on_lookup(context)
                if result is not None:
                    if not isinstance(result, CacheValue):
                        _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
                        continue
                    if not isinstance(result.outputs, (list, tuple)):
                        _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
                        continue
                    from execution import CacheEntry
                    return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
            except Exception as e:
                _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")

        return None

    def _is_external_cacheable_value(self, value):
        return hasattr(value, 'outputs') and hasattr(value, 'ui')

    def _get_class_type(self, node_id):
        if not self.initialized or not self.dynprompt:
            return ''
        try:
            return self.dynprompt.get_node(node_id).get('class_type', '')
        except Exception:
            return ''

    def _build_context(self, node_id, cache_key):
        from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
        try:
            cache_key_hash = _serialize_cache_key(cache_key)
            if cache_key_hash is None:
                return None
            return CacheContext(
                node_id=node_id,
                class_type=self._get_class_type(node_id),
                cache_key_hash=cache_key_hash,
            )
        except Exception as e:
            _logger.warning(f"Failed to build cache context for node {node_id}: {e}")
            return None

    async def _ensure_subcache(self, node_id, children_ids):
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key, None)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            self.subcaches[subcache_key] = subcache
        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache

    def _get_subcache(self, node_id):
        assert self.initialized
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        if subcache_key in self.subcaches:
            return self.subcaches[subcache_key]
        else:
            return None

    def recursive_debug_dump(self):
        result = []
        for key in self.cache:
            result.append({"key": key, "value": self.cache[key]})
        for key in self.subcaches:
            result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
        return result

class HierarchicalCache(BasicCache):
    def __init__(self, key_class):
        super().__init__(key_class)

    def _get_cache_for(self, node_id):
        assert self.dynprompt is not None
        parent_id = self.dynprompt.get_parent_node_id(node_id)
        if parent_id is None:
            return self

        hierarchy = []
        while parent_id is not None:
            hierarchy.append(parent_id)
            parent_id = self.dynprompt.get_parent_node_id(parent_id)

        cache = self
        for parent_id in reversed(hierarchy):
            cache = cache._get_subcache(parent_id)
            if cache is None:
                return None
        return cache

    async def get(self, node_id):
        cache = self._get_cache_for(node_id)
        if cache is None:
            return None
        return await cache._get_immediate(node_id)

    def get_local(self, node_id):
        cache = self._get_cache_for(node_id)
        if cache is None:
            return None
        return BasicCache.get_local(cache, node_id)

    async def set(self, node_id, value):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        await cache._set_immediate(node_id, value)

    def set_local(self, node_id, value):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        BasicCache.set_local(cache, node_id, value)

    async def ensure_subcache_for(self, node_id, children_ids):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        return await cache._ensure_subcache(node_id, children_ids)

class NullCache:

    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        pass

    def all_node_ids(self):
        return []

    def clean_unused(self):
        pass

    def poll(self, **kwargs):
        pass

    async def get(self, node_id):
        return None

    def get_local(self, node_id):
        return None

    async def set(self, node_id, value):
        pass

    def set_local(self, node_id, value):
        pass

    async def ensure_subcache_for(self, node_id, children_ids):
        return self

class LRUCache(BasicCache):
    def __init__(self, key_class, max_size=100):
        super().__init__(key_class)
        self.max_size = max_size
        self.min_generation = 0
        self.generation = 0
        self.used_generation = {}
        self.children = {}

    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)

    def clean_unused(self):
        while len(self.cache) > self.max_size and self.min_generation < self.generation:
            self.min_generation += 1
            to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
            for key in to_remove:
                del self.cache[key]
                del self.used_generation[key]
                if key in self.children:
                    del self.children[key]
        self._clean_subcaches()

    async def get(self, node_id):
        self._mark_used(node_id)
        return await self._get_immediate(node_id)

    def _mark_used(self, node_id):
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key is not None:
            self.used_generation[cache_key] = self.generation

    async def set(self, node_id, value):
        self._mark_used(node_id)
        return await self._set_immediate(node_id, value)

    async def ensure_subcache_for(self, node_id, children_ids):
        # Just uses subcaches for tracking 'live' nodes
        await super()._ensure_subcache(node_id, children_ids)

        await self.cache_key_set.add_keys(children_ids)
        self._mark_used(node_id)
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.children[cache_key] = []
        for child_id in children_ids:
            self._mark_used(child_id)
            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
        return self

# Iterating the cache for usage analysis might be expensive, so if we trigger make sure
# to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1

# This is kinda in GB but not really. It needs to be non-zero for the below heuristic,
# and as long as multi-GB models dwarf this it will approximate OOM scoring OK.
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1

# Exponential bias towards evicting older workflows so garbage will be taken out
# in constantly changing setups.
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

class RAMPressureCache(LRUCache):

    def __init__(self, key_class):
        super().__init__(key_class, 0)
        self.timestamps = {}

    def clean_unused(self):
        self._clean_subcaches()

    async def set(self, node_id, value):
        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
        await super().set(node_id, value)

    async def get(self, node_id):
        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
        return await super().get(node_id)

    def poll(self, ram_headroom):
        def _ram_gb():
            return psutil.virtual_memory().available / (1024**3)

        if _ram_gb() > ram_headroom:
            return
        gc.collect()
        if _ram_gb() > ram_headroom:
            return

        clean_list = []

        for key, (outputs, _) in self.cache.items():
            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])

            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE

            def scan_list_for_ram_usage(outputs):
                nonlocal ram_usage
                if outputs is None:
                    return
                for output in outputs:
                    if isinstance(output, list):
                        scan_list_for_ram_usage(output)
                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
                        # Score tensors at a 50% discount for RAM usage as they are likely to
                        # be high-value intermediates.
                        ram_usage += (output.numel() * output.element_size()) * 0.5
                    elif hasattr(output, "get_ram_usage"):
                        ram_usage += output.get_ram_usage()
            scan_list_for_ram_usage(outputs)

            oom_score *= ram_usage
            # In the case where we have no information on the node ram usage at all,
            # break OOM score ties on the last touch timestamp (pure LRU).
            bisect.insort(clean_list, (oom_score, self.timestamps[key], key))

        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
            _, _, key = clean_list.pop()
            del self.cache[key]
        gc.collect()
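The eviction order that `RAMPressureCache.poll` builds can be illustrated in isolation: each entry's OOM score is the age bias (`RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER` raised to the number of generations since last use) multiplied by its estimated RAM usage, entries are kept sorted with `bisect.insort`, and the highest-scoring entry is popped first. This is a standalone sketch with illustrative entry names and sizes, not the cache's real data structures.

```python
import bisect

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

def oom_score(generation, used_generation, ram_usage_gb):
    # Exponential bias toward entries untouched for more generations,
    # scaled by estimated RAM footprint.
    return (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER
            ** (generation - used_generation)) * ram_usage_gb

# Hypothetical cache entries: key -> (used_generation, ram_gb, last_touch_ts)
entries = {
    "old_big":   (1, 4.0, 100.0),
    "new_big":   (5, 4.0, 200.0),
    "old_small": (1, 0.1, 150.0),
}

clean_list = []
current_generation = 5
for key, (used_gen, ram, ts) in entries.items():
    # Timestamp breaks ties between equal scores (pure LRU fallback).
    bisect.insort(clean_list, (oom_score(current_generation, used_gen, ram), ts, key))

# poll() pops from the end: highest score is evicted first.
eviction_order = [clean_list.pop()[2] for _ in range(len(clean_list))]
print(eviction_order)  # ['old_big', 'new_big', 'old_small']
```

A large entry untouched for four generations (score 1.3^4 x 4.0 ≈ 11.42) is evicted before an equally large fresh one (score 4.0), and a small stale entry (≈ 0.29) survives longest, which matches the "take the garbage out of changing setups" intent of the multiplier.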