feat: add expected_outputs feature for lazy output computation

This commit is contained in:
bigcat88
2026-02-04 14:26:16 +02:00
parent 2b70ab9ad0
commit d987b0d32d
8 changed files with 515 additions and 18 deletions

View File

@@ -31,6 +31,7 @@ from comfy_execution.graph import (
ExecutionBlocker,
ExecutionList,
get_input_info,
get_expected_outputs_for_node,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
@@ -227,7 +228,18 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
async def _async_map_node_over_list(
prompt_id,
unique_id,
obj,
input_data_all,
func,
allow_interrupt=False,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -277,10 +289,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
else:
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
return await f(**args)
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
task = asyncio.create_task(
async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
)
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
if task.done():
@@ -289,7 +303,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
else:
results.append(task)
else:
with CurrentNodeContext(prompt_id, unique_id, index):
with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
result = f(**inputs)
results.append(result)
else:
@@ -327,8 +341,17 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
async def get_output_data(
prompt_id,
unique_id,
obj,
input_data_all,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -522,9 +545,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
#that we just want to cull out each model run.
allocator = comfy.memory_management.aimdo_allocator
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
try:
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
finally:
if allocator is not None:
comfy.model_management.reset_cast_buffers()