feat: add expected_outputs feature for lazy output computation

This commit is contained in:
bigcat88
2026-02-04 14:26:16 +02:00
parent 2b70ab9ad0
commit d987b0d32d
8 changed files with 515 additions and 18 deletions

View File

@@ -5,7 +5,7 @@ import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from abc import ABC, abstractmethod
import nodes
@@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet):
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
expected = get_expected_outputs_for_node(dynprompt, node_id)
signature.append(("expected_outputs", tuple(sorted(expected))))
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):

View File

@@ -19,6 +19,27 @@ class NodeInputError(Exception):
class NodeNotFoundError(Exception):
pass
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
"""Get the set of output indices that are connected downstream.
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
"""
expected = set()
for other_node_id in dynprompt.all_node_ids():
try:
node_data = dynprompt.get_node(other_node_id)
except NodeNotFoundError:
continue
inputs = node_data.get("inputs", {})
for input_name, value in inputs.items():
if is_link(value):
from_node_id, from_socket = value
if from_node_id == node_id:
expected.add(from_socket)
return frozenset(expected)
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user

View File

@@ -1,21 +1,26 @@
import contextvars
from typing import Optional, NamedTuple
from typing import NamedTuple, FrozenSet
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
Attributes:
prompt_id: The ID of the current prompt execution
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
expected_outputs: Set of output indices that might be used downstream.
Outputs NOT in this set are definitely unused (safe to skip).
None means the information is not available.
"""
prompt_id: str
node_id: str
list_index: Optional[int]
list_index: int | None
expected_outputs: FrozenSet[int] | None = None
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> Optional[ExecutionContext]:
def get_executing_context() -> ExecutionContext | None:
return current_executing_context.get(None)
class CurrentNodeContext:
@@ -25,15 +30,22 @@ class CurrentNodeContext:
Sets the current_executing_context on enter and resets it on exit.
Example:
with CurrentNodeContext(node_id="123", list_index=0):
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
def __init__(
self,
prompt_id: str,
node_id: str,
list_index: int | None = None,
expected_outputs: FrozenSet[int] | None = None,
):
self.context = ExecutionContext(
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
prompt_id=prompt_id,
node_id=node_id,
list_index=list_index,
expected_outputs=expected_outputs,
)
self.token = None