mirror of https://github.com/comfyanonymous/ComfyUI.git
feat: add expected_outputs feature for lazy output computation
@@ -5,7 +5,7 @@ import psutil
 import time
 import torch
 from typing import Sequence, Mapping, Dict
-from comfy_execution.graph import DynamicPrompt
+from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
 from abc import ABC, abstractmethod
 
 import nodes
@@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet):
         signature = [class_type, await self.is_changed_cache.get(node_id)]
         if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
             signature.append(node_id)
+        # Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
+        if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
+            expected = get_expected_outputs_for_node(dynprompt, node_id)
+            signature.append(("expected_outputs", tuple(sorted(expected))))
         inputs = node["inputs"]
         for key in sorted(inputs.keys()):
             if is_link(inputs[key]):
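
For reference, the LAZY_OUTPUTS attribute checked above is a class-level opt-in flag on a node definition. Below is a minimal sketch of such a node (not part of this diff): the class name, its outputs, and the mask-as-second-output choice are hypothetical, while RETURN_TYPES/FUNCTION/INPUT_TYPES follow the usual ComfyUI node conventions.

# Hypothetical node class that opts in to lazy output computation.
# Because what it computes can now depend on which outputs are wired up
# downstream, its cache key must also depend on expected_outputs (the hunk above).
class ExpensiveMultiOutput:
    LAZY_OUTPUTS = True  # the flag the cache-key code checks via hasattr()/truthiness
    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("image", "mask")
    FUNCTION = "run"
    CATEGORY = "examples/lazy"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    def run(self, image):
        # See the ExecutionContext changes later in this commit for how the
        # expensive mask branch can be skipped when output 1 is not expected.
        return (image, None)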
@@ -19,6 +19,27 @@ class NodeInputError(Exception):
 class NodeNotFoundError(Exception):
     pass
 
+
+def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
+    """Get the set of output indices that are connected downstream.
+    Returns outputs that MIGHT be used.
+    Outputs NOT in this set are DEFINITELY not used and safe to skip.
+    """
+    expected = set()
+    for other_node_id in dynprompt.all_node_ids():
+        try:
+            node_data = dynprompt.get_node(other_node_id)
+        except NodeNotFoundError:
+            continue
+        inputs = node_data.get("inputs", {})
+        for input_name, value in inputs.items():
+            if is_link(value):
+                from_node_id, from_socket = value
+                if from_node_id == node_id:
+                    expected.add(from_socket)
+    return frozenset(expected)
+
+
 class DynamicPrompt:
     def __init__(self, original_prompt):
         # The original prompt provided by the user
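
The new helper walks every node's inputs looking for links back to node_id. Below is a self-contained sketch of the same scan over a plain prompt dict (not part of this diff): the graph and node names are made up, and the ["from_node_id", output_index] link encoding is the shape the is_link() check matches on.

# Illustration of the downstream scan done by get_expected_outputs_for_node().
prompt = {
    "1": {"class_type": "LoadImage", "inputs": {}},
    "2": {"class_type": "ExpensiveMultiOutput", "inputs": {"image": ["1", 0]}},
    "3": {"class_type": "SaveImage", "inputs": {"images": ["2", 0]}},
    # Nothing consumes output index 1 of node "2", so it is safe to skip.
}

def expected_outputs(prompt, node_id):
    expected = set()
    for node in prompt.values():
        for value in node.get("inputs", {}).values():
            # A linked input is encoded as [from_node_id, from_output_index].
            if isinstance(value, list) and len(value) == 2 and value[0] == node_id:
                expected.add(value[1])
    return frozenset(expected)

print(expected_outputs(prompt, "2"))  # frozenset({0}) -> output index 1 may be skipped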
@@ -1,21 +1,26 @@
 import contextvars
-from typing import Optional, NamedTuple
+from typing import NamedTuple, FrozenSet
 
 class ExecutionContext(NamedTuple):
     """
     Context information about the currently executing node.
 
     Attributes:
         prompt_id: The ID of the current prompt execution
         node_id: The ID of the currently executing node
         list_index: The index in a list being processed (for operations on batches/lists)
+        expected_outputs: Set of output indices that might be used downstream.
+            Outputs NOT in this set are definitely unused (safe to skip).
+            None means the information is not available.
     """
     prompt_id: str
     node_id: str
-    list_index: Optional[int]
+    list_index: int | None
+    expected_outputs: FrozenSet[int] | None = None
 
-current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
+current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
 
-def get_executing_context() -> Optional[ExecutionContext]:
+def get_executing_context() -> ExecutionContext | None:
     return current_executing_context.get(None)
 
 class CurrentNodeContext:
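
A LAZY_OUTPUTS node can read this field at run time to decide which outputs to actually compute. That consumption pattern is not shown in these hunks; in the sketch below the import path and compute_expensive_mask() are assumptions, and the method belongs to the hypothetical ExpensiveMultiOutput node sketched earlier.

from comfy_execution.utils import get_executing_context  # import path assumed

def run(self, image):
    ctx = get_executing_context()
    want = ctx.expected_outputs if ctx is not None else None

    # Output index 1 (the mask) is definitely unused when it is absent from
    # expected_outputs; None means "unknown", so compute everything in that case.
    mask = None
    if want is None or 1 in want:
        mask = compute_expensive_mask(image)  # hypothetical expensive step
    return (image, mask)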
@@ -25,15 +30,22 @@ class CurrentNodeContext:
     Sets the current_executing_context on enter and resets it on exit.
 
     Example:
-        with CurrentNodeContext(node_id="123", list_index=0):
+        with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
             # Code that should run with the current node context set
             process_image()
     """
-    def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
+    def __init__(
+        self,
+        prompt_id: str,
+        node_id: str,
+        list_index: int | None = None,
+        expected_outputs: FrozenSet[int] | None = None,
+    ):
         self.context = ExecutionContext(
-            prompt_id= prompt_id,
-            node_id= node_id,
-            list_index= list_index
+            prompt_id=prompt_id,
+            node_id=node_id,
+            list_index=list_index,
+            expected_outputs=expected_outputs,
         )
         self.token = None
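
For these pieces to meet, the executor has to compute the expected set (only for opted-in nodes, mirroring the cache-key condition) and pass it through CurrentNodeContext. That wiring lives in execution code not shown in these hunks, so the sketch below is only an assumption about how it presumably fits together; the names class_def, dynprompt, obj, input_kwargs and the second import path are illustrative.

from comfy_execution.graph import get_expected_outputs_for_node
from comfy_execution.utils import CurrentNodeContext  # import path assumed

# Compute the expected set only for nodes that opt in via LAZY_OUTPUTS.
expected = None
if getattr(class_def, "LAZY_OUTPUTS", False):
    expected = get_expected_outputs_for_node(dynprompt, node_id)

# Make it visible to the node through the execution context while it runs.
with CurrentNodeContext(prompt_id, node_id, list_index, expected_outputs=expected):
    results = getattr(obj, obj.FUNCTION)(**input_kwargs)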