fix: precompute expected_outputs map to avoid O(n²) graph traversal

This commit is contained in:
bigcat88
2026-02-06 08:29:01 +02:00
parent 50975a7a0d
commit 01ef4e50ec
2 changed files with 47 additions and 13 deletions

View File

@@ -25,19 +25,7 @@ def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
"""
expected = set()
for other_node_id in dynprompt.all_node_ids():
try:
node_data = dynprompt.get_node(other_node_id)
except NodeNotFoundError:
continue
inputs = node_data.get("inputs", {})
for input_name, value in inputs.items():
if is_link(value):
from_node_id, from_socket = value
if from_node_id == node_id:
expected.add(from_socket)
return frozenset(expected)
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
class DynamicPrompt:
@@ -48,6 +36,7 @@ class DynamicPrompt:
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
self._expected_outputs_map = None
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
@@ -63,6 +52,7 @@ class DynamicPrompt:
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
self._expected_outputs_map = None
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
@@ -80,6 +70,26 @@ class DynamicPrompt:
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def _build_expected_outputs_map(self):
result = {}
for node_id in self.all_node_ids():
try:
node_data = self.get_node(node_id)
except NodeNotFoundError:
continue
for value in node_data.get("inputs", {}).values():
if is_link(value):
from_node_id, from_socket = value
if from_node_id not in result:
result[from_node_id] = set()
result[from_node_id].add(from_socket)
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
def get_expected_outputs_map(self):
if self._expected_outputs_map is None:
self._build_expected_outputs_map()
return self._expected_outputs_map
def get_original_prompt(self):
return self.original_prompt