mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-19 06:30:07 +00:00
fix: precompute expected_outputs map to avoid O(n²) graph traversal
This commit is contained in:
@@ -25,19 +25,7 @@ def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
||||
Returns outputs that MIGHT be used.
|
||||
Outputs NOT in this set are DEFINITELY not used and safe to skip.
|
||||
"""
|
||||
expected = set()
|
||||
for other_node_id in dynprompt.all_node_ids():
|
||||
try:
|
||||
node_data = dynprompt.get_node(other_node_id)
|
||||
except NodeNotFoundError:
|
||||
continue
|
||||
inputs = node_data.get("inputs", {})
|
||||
for input_name, value in inputs.items():
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if from_node_id == node_id:
|
||||
expected.add(from_socket)
|
||||
return frozenset(expected)
|
||||
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
@@ -48,6 +36,7 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
@@ -63,6 +52,7 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
@@ -80,6 +70,26 @@ class DynamicPrompt:
|
||||
def all_node_ids(self):
|
||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||
|
||||
def _build_expected_outputs_map(self):
|
||||
result = {}
|
||||
for node_id in self.all_node_ids():
|
||||
try:
|
||||
node_data = self.get_node(node_id)
|
||||
except NodeNotFoundError:
|
||||
continue
|
||||
for value in node_data.get("inputs", {}).values():
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if from_node_id not in result:
|
||||
result[from_node_id] = set()
|
||||
result[from_node_id].add(from_socket)
|
||||
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
|
||||
|
||||
def get_expected_outputs_map(self):
|
||||
if self._expected_outputs_map is None:
|
||||
self._build_expected_outputs_map()
|
||||
return self._expected_outputs_map
|
||||
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user