Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-28 19:04:04 +00:00)
feat: Add CacheProvider API for external distributed caching
Introduces a public API for external cache providers, enabling distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods).

New files:
- comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue dataclasses, thread-safe provider registry, serialization utilities

Modified files:
- comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store, _check_providers_lookup), subcache exclusion, prompt ID propagation
- execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to PromptExecutor, set _current_prompt_id on caches

Key features:
- Local-first caching (check the local cache before the external one, for performance)
- NaN detection to prevent incorrect external cache hits
- Subcache exclusion (ephemeral subgraph results are not cached externally)
- Thread-safe provider snapshot caching
- Graceful error handling (provider errors are logged and never break execution)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
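Editor's note: only the execution.py side of the change is shown on this page; comfy_execution/cache_provider.py itself is not included. As a rough sketch of how a provider might look, consider the class below. Confirmed by the diff: providers are obtained via get_cache_providers() and must expose on_prompt_start()/on_prompt_end(). Everything else (the lookup/store method names and the dict backing) is an illustrative assumption about the cache_provider.py side, not the committed API.

# Sketch only: of these names, just on_prompt_start/on_prompt_end are
# visible in this commit's diff; lookup/store are guesses implied by the
# _check_providers_lookup/_notify_providers_store hooks named above.

class InMemorySharedCacheProvider:
    """Toy provider backed by a local dict; a real deployment would use a
    shared store such as Redis so multiple pods can hit the same cache."""

    def __init__(self):
        self._store = {}

    # Lifecycle hooks invoked by PromptExecutor._notify_prompt_lifecycle():
    def on_prompt_start(self, prompt_id: str):
        print(f"prompt {prompt_id} started")

    def on_prompt_end(self, prompt_id: str):
        print(f"prompt {prompt_id} ended")

    # Assumed storage hooks, shaped after the BasicCache hook names:
    def lookup(self, key: str):
        return self._store.get(key)

    def store(self, key: str, value) -> None:
        self._store[key] = value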
Changed files: execution.py (137 changed lines)
@@ -669,6 +669,22 @@ class PromptExecutor:
             }
 
             self.add_message("execution_error", mes, broadcast=False)
 
+    def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
+        """Notify external cache providers of prompt lifecycle events."""
+        from comfy_execution.cache_provider import has_cache_providers, get_cache_providers, logger
+
+        if not has_cache_providers():
+            return
+
+        for provider in get_cache_providers():
+            try:
+                if event == "start":
+                    provider.on_prompt_start(prompt_id)
+                elif event == "end":
+                    provider.on_prompt_end(prompt_id)
+            except Exception as e:
+                logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
+
     def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
         asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
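Note the error isolation in the hunk above: each provider call runs inside its own try/except, so a failing provider is logged via logger.warning and skipped rather than aborting the prompt (the commit's "provider errors logged, never break execution" guarantee). A hypothetical always-failing provider illustrates the contract:

# Hypothetical stand-in used only to illustrate the error-isolation contract.
class FailingProvider:
    def on_prompt_start(self, prompt_id: str):
        raise RuntimeError("backend unreachable")

    def on_prompt_end(self, prompt_id: str):
        raise RuntimeError("backend unreachable")

# With such a provider registered, _notify_prompt_lifecycle() logs
# "Cache provider FailingProvider error on start: backend unreachable"
# and continues; remaining providers still run and execution proceeds.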
@@ -685,66 +701,77 @@
         self.status_messages = []
         self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
 
-        with torch.inference_mode():
-            dynamic_prompt = DynamicPrompt(prompt)
-            reset_progress_state(prompt_id, dynamic_prompt)
-            add_progress_handler(WebUIProgressHandler(self.server))
-            is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
-            for cache in self.caches.all:
-                await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
-                cache.clean_unused()
+        # Set prompt ID on caches for external provider integration
+        for cache in self.caches.all:
+            cache._current_prompt_id = prompt_id
 
-            cached_nodes = []
-            for node_id in prompt:
-                if self.caches.outputs.get(node_id) is not None:
-                    cached_nodes.append(node_id)
+        # Notify external cache providers of prompt start
+        self._notify_prompt_lifecycle("start", prompt_id)
 
-            comfy.model_management.cleanup_models_gc()
-            self.add_message("execution_cached",
-                          { "nodes": cached_nodes, "prompt_id": prompt_id},
-                          broadcast=False)
-            pending_subgraph_results = {}
-            pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
-            ui_node_outputs = {}
-            executed = set()
-            execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
-            current_outputs = self.caches.outputs.all_node_ids()
-            for node_id in list(execute_outputs):
-                execution_list.add_node(node_id)
+        try:
+            with torch.inference_mode():
+                dynamic_prompt = DynamicPrompt(prompt)
+                reset_progress_state(prompt_id, dynamic_prompt)
+                add_progress_handler(WebUIProgressHandler(self.server))
+                is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
+                for cache in self.caches.all:
+                    await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+                    cache.clean_unused()
 
-            while not execution_list.is_empty():
-                node_id, error, ex = await execution_list.stage_node_execution()
-                if error is not None:
-                    self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
-                    break
+                cached_nodes = []
+                for node_id in prompt:
+                    if self.caches.outputs.get(node_id) is not None:
+                        cached_nodes.append(node_id)
 
-                assert node_id is not None, "Node ID should not be None at this point"
-                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
-                self.success = result != ExecutionResult.FAILURE
-                if result == ExecutionResult.FAILURE:
-                    self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
-                    break
-                elif result == ExecutionResult.PENDING:
-                    execution_list.unstage_node_execution()
-                else: # result == ExecutionResult.SUCCESS:
-                    execution_list.complete_node_execution()
-                self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
-            else:
-                # Only execute when the while-loop ends without break
-                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                comfy.model_management.cleanup_models_gc()
+                self.add_message("execution_cached",
+                              { "nodes": cached_nodes, "prompt_id": prompt_id},
+                              broadcast=False)
+                pending_subgraph_results = {}
+                pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+                ui_node_outputs = {}
+                executed = set()
+                execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+                current_outputs = self.caches.outputs.all_node_ids()
+                for node_id in list(execute_outputs):
+                    execution_list.add_node(node_id)
 
-            ui_outputs = {}
-            meta_outputs = {}
-            for node_id, ui_info in ui_node_outputs.items():
-                ui_outputs[node_id] = ui_info["output"]
-                meta_outputs[node_id] = ui_info["meta"]
-            self.history_result = {
-                "outputs": ui_outputs,
-                "meta": meta_outputs,
-            }
-            self.server.last_node_id = None
-            if comfy.model_management.DISABLE_SMART_MEMORY:
-                comfy.model_management.unload_all_models()
+                while not execution_list.is_empty():
+                    node_id, error, ex = await execution_list.stage_node_execution()
+                    if error is not None:
+                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                        break
+
+                    assert node_id is not None, "Node ID should not be None at this point"
+                    result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+                    self.success = result != ExecutionResult.FAILURE
+                    if result == ExecutionResult.FAILURE:
+                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                        break
+                    elif result == ExecutionResult.PENDING:
+                        execution_list.unstage_node_execution()
+                    else: # result == ExecutionResult.SUCCESS:
+                        execution_list.complete_node_execution()
+                    self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+                else:
+                    # Only execute when the while-loop ends without break
+                    self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+
+                ui_outputs = {}
+                meta_outputs = {}
+                for node_id, ui_info in ui_node_outputs.items():
+                    ui_outputs[node_id] = ui_info["output"]
+                    meta_outputs[node_id] = ui_info["meta"]
+                self.history_result = {
+                    "outputs": ui_outputs,
+                    "meta": meta_outputs,
+                }
+                self.server.last_node_id = None
+                if comfy.model_management.DISABLE_SMART_MEMORY:
+                    comfy.model_management.unload_all_models()
+        finally:
+            # Notify external cache providers of prompt end
+            self._notify_prompt_lifecycle("end", prompt_id)
 
 
 async def validate_inputs(prompt_id, prompt, item, validated):
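Two guarantees are worth noting in the hunk above: _current_prompt_id is set on every cache and on_prompt_start fires before the try block, while on_prompt_end sits in the finally clause, so the start/end pair brackets every execution whether it succeeds, fails, or raises. A provider can lean on that pairing for per-prompt bookkeeping; a minimal sketch follows (only the two hook names come from this commit, the rest is illustrative):

import threading

class PromptScopedProvider:
    """Drops per-prompt scratch state as soon as the prompt finishes."""

    def __init__(self):
        self._lock = threading.Lock()
        self._inflight = {}  # prompt_id -> scratch state

    def on_prompt_start(self, prompt_id: str):
        with self._lock:
            self._inflight[prompt_id] = {}

    def on_prompt_end(self, prompt_id: str):
        # Runs even when execution raised, because the executor calls
        # this hook from its finally block.
        with self._lock:
            self._inflight.pop(prompt_id, None)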