Compare commits

..

5 Commits

Author SHA1 Message Date
Jacob Segal
0254d9cc11 Add additional tests for async error cases
Also fixes one bug that was found when an async function throws an error
after being scheduled on a task.
2025-07-01 17:13:27 -07:00
Jacob Segal
92f9a10782 Add the websocket library for automated tests 2025-07-01 14:41:52 -07:00
Jacob Segal
a6a6b615f4 Add a missing file
It looks like this got caught by .gitignore? There's probably a better
place to put it, but I'm not sure what that is.
2025-07-01 14:41:52 -07:00
Jacob Segal
50bf72f852 Add the execution model tests to CI 2025-07-01 14:41:52 -07:00
Jacob Segal
46c8311d14 Support for async execution functions
This commit adds support for node execution functions defined as async. When
a node's execution function is defined as async, we can continue
executing other nodes while it is processing.

Standard uses of `await` should "just work", but people will still have
to be careful if they spawn actual threads. Because torch doesn't really
have async/await versions of functions, this won't particularly help
with most locally-executing nodes, but it does work for e.g. web
requests to other machines.

In addition to the execute function, the `VALIDATE_INPUTS` and
`check_lazy_status` functions can also be defined as async, though we'll
only resolve one node at a time right now for those.
2025-07-01 14:41:52 -07:00
38 changed files with 1616 additions and 1093 deletions

View File

@@ -28,3 +28,7 @@ jobs:
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit
- name: Run Execution Model Tests
run: |
python -m pytest tests/inference/test_execution.py

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "128"
python_minor:
description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "5"
default: "2"
# push:
# branches:
# - master
@@ -53,8 +53,6 @@ jobs:
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd

View File

@@ -86,7 +86,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
@@ -102,6 +101,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -243,7 +243,7 @@ Nvidia users should install stable pytorch using this command:
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
#### Troubleshooting

View File

@@ -1,10 +1,55 @@
import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention, FeedForward
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
import comfy.ops
ops = comfy.ops.manual_cast
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):

View File

@@ -412,13 +412,9 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@@ -1071,9 +1067,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
@@ -1091,7 +1085,6 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
@@ -1115,9 +1108,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
h_n = (t_next - t_cur)
@@ -1157,7 +1148,6 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
@@ -1208,7 +1198,6 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
return x_next
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
@@ -1415,7 +1404,6 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
@@ -1442,19 +1430,19 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
if i == 0:
# Euler method
if cfg_pp:
x = denoised + d * sigmas[i + 1]
else:
x = x + d * dt
if i >= 1:
# Gradient estimation
else:
# Gradient estimation
if cfg_pp:
d_bar = (ge_gamma - 1) * (d - old_d)
x = denoised + d * sigmas[i + 1] + d_bar * dt
else:
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x

View File

@@ -379,9 +379,6 @@ class ModelPatcher:
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
self.model_options["model_function_wrapper"] = unet_wrapper_function

View File

@@ -336,12 +336,9 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
try:
out = fp8_linear(self, input)
if out is not None:
return out
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

View File

@@ -373,11 +373,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
uncond_ = uncond
conds = [cond, uncond_]
if "sampler_calc_cond_batch_function" in model_options:
args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = model_options["sampler_calc_cond_batch_function"](args)
else:
out = calc_cond_batch(model, conds, x, timestep, model_options)
out = calc_cond_batch(model, conds, x, timestep, model_options)
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,

View File

@@ -77,7 +77,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
@@ -998,11 +997,12 @@ def set_progress_bar_global_hook(function):
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
def __init__(self, total, node_id=None):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
self.node_id = node_id
def update_absolute(self, value, total=None, preview=None):
if total is not None:
@@ -1011,7 +1011,7 @@ class ProgressBar:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total, preview)
self.hook(self.current, self.total, preview, node_id=self.node_id)
def update(self, value):
self.update_absolute(self.current + value)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
from typing import Optional
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
@@ -32,22 +31,6 @@ class VideoInput(ABC):
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
loading the entire video into memory.
Returns:
Either a file path (str) or a BytesIO object that can be opened with av.
Default implementation creates a BytesIO buffer, but subclasses should
override this for better performance when possible.
"""
buffer = io.BytesIO()
self.save_to(buffer)
buffer.seek(0)
return buffer
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:

View File

@@ -64,15 +64,6 @@ class VideoFromFile(VideoInput):
"""
self.__file = file
def get_stream_source(self) -> str | io.BytesIO:
"""
Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
return self.__file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.

View File

@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-07-06T09:47:31+00:00
# timestamp: 2025-05-19T21:38:55+00:00
from __future__ import annotations
@@ -1355,158 +1355,6 @@ class ModelResponseProperties(BaseModel):
)
class Keyframes(BaseModel):
image_url: Optional[str] = None
class MoonvalleyPromptResponse(BaseModel):
error: Optional[Dict[str, Any]] = None
frame_conditioning: Optional[Dict[str, Any]] = None
id: Optional[str] = None
inference_params: Optional[Dict[str, Any]] = None
meta: Optional[Dict[str, Any]] = None
model_params: Optional[Dict[str, Any]] = None
output_url: Optional[str] = None
prompt_text: Optional[str] = None
status: Optional[str] = None
class MoonvalleyTextToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
None, description='Number of cooldown steps (calculated based on num_frames)'
)
fps: Optional[int] = Field(
24, description='Frames per second of the generated video'
)
guidance_scale: Optional[float] = Field(
12.5, description='Guidance scale for generation control'
)
height: Optional[int] = Field(
1080, description='Height of the generated video in pixels'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
None, description='Number of warmup steps (calculated based on num_frames)'
)
width: Optional[int] = Field(
1920, description='Width of the generated video in pixels'
)
class MoonvalleyTextToVideoRequest(BaseModel):
image_url: Optional[str] = None
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
prompt_text: Optional[str] = None
webhook_url: Optional[str] = None
class MoonvalleyUploadFileRequest(BaseModel):
file: Optional[StrictBytes] = None
class MoonvalleyUploadFileResponse(BaseModel):
access_url: Optional[str] = None
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
None, description='Number of cooldown steps (calculated based on num_frames)'
)
guidance_scale: Optional[float] = Field(
12.5, description='Guidance scale for generation control'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
None, description='Number of warmup steps (calculated based on num_frames)'
)
class ControlType(str, Enum):
motion_control = 'motion_control'
pose_control = 'pose_control'
class MoonvalleyVideoToVideoRequest(BaseModel):
control_type: ControlType = Field(
..., description='Supported types for video control'
)
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
prompt_text: str = Field(..., description='Describes the video to generate')
video_url: str = Field(..., description='Url to control video')
webhook_url: Optional[str] = Field(
None, description='Optional webhook URL for notifications'
)
class Moderation(str, Enum):
low = 'low'
auto = 'auto'
@@ -3259,23 +3107,6 @@ class LumaUpscaleVideoGenerationRequest(BaseModel):
resolution: Optional[LumaVideoModelOutputResolution] = None
class MoonvalleyImageToVideoRequest(MoonvalleyTextToVideoRequest):
keyframes: Optional[Dict[str, Keyframes]] = None
class MoonvalleyResizeVideoRequest(MoonvalleyVideoToVideoRequest):
frame_position: Optional[List[int]] = Field(None, max_length=2, min_length=2)
frame_resolution: Optional[List[int]] = Field(None, max_length=2, min_length=2)
scale: Optional[List[int]] = Field(None, max_length=2, min_length=2)
class MoonvalleyTextToImageRequest(BaseModel):
image_url: Optional[str] = None
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
prompt_text: Optional[str] = None
webhook_url: Optional[str] = None
class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]):
root: Union[OutputTextContent, OutputAudioContent]

View File

@@ -1,639 +0,0 @@
import logging
from typing import Any, Callable, Optional, TypeVar
import random
import torch
from comfy_api_nodes.util.validation_utils import get_image_dimensions, validate_image_dimensions, validate_video_dimensions
from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input.video_types import VideoInput
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl import VideoFromFile
import av
import io
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
MIN_WIDTH = 300
MIN_HEIGHT = 300
MAX_WIDTH = 10000
MAX_HEIGHT = 10000
MIN_VID_WIDTH = 300
MIN_VID_HEIGHT = 300
MAX_VID_WIDTH = 10000
MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
"""Verifies that the initial response contains a task ID."""
return bool(response.id)
def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise MoonvalleyApiError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
)
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Moonvalley video generation task result.
Will not raise an error if the response is not valid.
"""
if response:
return str(get_video_from_response(response))
else:
return None
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
],
max_poll_attempts=240, # 64 minutes with 16s interval
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status
if response and response.status
else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(prompt:str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
if len(prompt) > max_length:
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
if negative_prompt and len(negative_prompt) > max_length:
raise ValueError(
f"Negative prompt is too long: {len(negative_prompt)} characters"
)
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0 :
return False, (
"The input video total frame count must be divisible by 16!"
)
if height % 8 != 0 or width % 8 != 0:
return False, (
f"Height ({height}) and width ({width}) must be " "divisible by 8"
)
if with_frame_conditioning:
if (height * width) % 8192 != 0:
return False, (
f"Height * width ({height * width}) must be "
"divisible by 8192 for frame conditioning"
)
else:
if num_frames_in and not num_frames_in % 32 == 0 :
return False, (
"The input video total frame count must be divisible by 32!"
)
def validate_input_image(image: torch.Tensor, with_frame_conditioning: bool=False) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
"""
height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning )
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
def validate_input_video(video: VideoInput, num_frames_out: int, with_frame_conditioning: bool=False):
try:
width, height = video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e
validate_input_media(width, height, with_frame_conditioning)
validate_video_dimensions(video, min_width=MIN_VID_WIDTH, min_height=MIN_VID_HEIGHT, max_width=MAX_VID_WIDTH, max_height=MAX_VID_HEIGHT)
trimmed_video = validate_input_video_length(video, num_frames_out)
return trimmed_video
def validate_input_video_length(video: VideoInput, num_frames: int):
if video.get_duration() > 60:
raise MoonvalleyApiError("Input Video lenth should be less than 1min. Please trim.")
if num_frames == 128:
if video.get_duration() < 5:
raise MoonvalleyApiError("Input Video length is less than 5s. Please use a video longer than or equal to 5s.")
if video.get_duration() > 5:
# trim video to 5s
video = trim_video(video, 5)
if num_frames == 256:
if video.get_duration() < 10:
raise MoonvalleyApiError("Input Video length is less than 10s. Please use a video longer than or equal to 10s.")
if video.get_duration() > 10:
# trim video to 10s
video = trim_video(video, 10)
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = io.BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode='r')
output_container = av.open(output_buffer, mode='w', format='mp4')
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream('h264', rate=stream.average_rate)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = 'yuv420p'
logging.info(f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps")
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream('aac', rate=stream.sample_rate)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info(f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels")
# Calculate target frame count that's divisible by 32
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 32) * 32 # Round down to nearest multiple of 32
if target_frames == 0:
raise ValueError("Video too short: need at least 32 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {frame_count} video frames (target: {target_frames})")
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames")
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
# --- BaseMoonvalleyVideoNode ---
class BaseMoonvalleyVideoNode:
def parseWidthHeightFromRes(self, resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1440 x 1080)": {"width": 1440, "height": 1080},
"3:4 (1080 x 1440)": {"width": 1080, "height": 1440},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
if resolution in res_map:
return res_map[resolution]
else:
# Default to 1920x1080 if unknown
return {"width": 1920, "height": 1080}
def parseControlParameter(self, value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control"
}
if value in control_map:
return control_map[value]
else:
return control_map["Motion Transfer"]
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MoonvalleyPromptResponse,
),
result_url_extractor=get_video_url_from_response,
node_id=node_id,
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, MoonvalleyTextToVideoRequest, "prompt_text",
multiline=True
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="gopro, bright, contrast, static, overexposed, bright, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, contrast, saturated, vibrant, glowing, cross dissolve, texture, videogame, saturation, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, transition, dissolve, cross-dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring, static",
),
"resolution": (IO.COMBO, {
"options": ["16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1440 x 1080)",
"3:4 (1080 x 1440)",
"21:9 (2560 x 1080)"],
"default": "16:9 (1920 x 1080)",
"tooltip": "Resolution of the output video",
}),
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
"prompt_adherence": model_field_to_node_input(IO.FLOAT,MoonvalleyTextToVideoInferenceParams,"guidance_scale",default=7.0, step=1, min=1, max=20),
"seed": model_field_to_node_input(IO.INT,MoonvalleyTextToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
"steps": model_field_to_node_input(IO.INT, MoonvalleyTextToVideoInferenceParams, "steps", default=100, min=1, max=100),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "generate"
CATEGORY = "api node/video/Moonvalley Marey"
API_NODE = True
def generate(self, **kwargs):
return None
# --- MoonvalleyImg2VideoNode ---
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
@classmethod
def INPUT_TYPES(cls):
return super().INPUT_TYPES()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node"
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
image = kwargs.get("image", None)
if (image is None):
raise MoonvalleyApiError("image is required")
total_frames = get_total_frames_from_length()
validate_input_image(image,True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=total_frames,
width=width_height.get("width"),
height=width_height.get("height"),
use_negative_prompts=True
)
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type)[0]
request = MoonvalleyTextToVideoRequest(
image_url=image_url,
prompt_text=prompt,
inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video, )
# --- MoonvalleyVid2VidNode ---
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
for param in ["resolution", "image"]:
if param in input_types["required"]:
del input_types["required"][param]
if param in input_types["optional"]:
del input_types["optional"][param]
input_types["optional"] = {
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically."}),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
),
"motion_intensity": (
"INT",
{
"default": 100,
"step": 1,
"min": 0,
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
)
}
return input_types
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
video = kwargs.get("video")
num_frames = get_total_frames_from_length()
if not video :
raise MoonvalleyApiError("video is required")
"""Validate video input"""
video_url=""
if video:
validated_video = validate_input_video(video, num_frames, False)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
inference_params=MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
control_params={'motion_intensity': motion_intensity}
)
control = self.parseControlParameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyVideoToVideoRequest,
response_model=MoonvalleyPromptResponse
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video, )
# --- MoonvalleyTxt2VideoNode ---
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
# Remove image-specific parameters
for param in ["image"]:
if param in input_types["optional"]:
del input_types["optional"][param]
return input_types
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
num_frames = get_total_frames_from_length()
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=num_frames,
width=width_height.get("width"),
height=width_height.get("height"),
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt,
inference_params=inference_params
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse
),
request=request,
auth_kwargs=kwargs,
)
task_creation_response = initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
video = download_url_to_video_output(final_response.output_url)
return (video, )
NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
# "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
# "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
}
def get_total_frames_from_length(length="5s"):
# if length == '5s':
# return 128
# elif length == '10s':
# return 256
return 128
# else:
# raise MoonvalleyApiError("length is required")

View File

@@ -1,6 +1,7 @@
import itertools
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
import nodes
@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet:
class CacheKeySet(ABC):
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
self.subcache_keys = {}
def add_keys(self, node_ids):
@abstractmethod
async def add_keys(self, node_ids):
raise NotImplementedError()
def all_node_ids(self):
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.add_keys(node_ids)
def add_keys(self, node_ids):
async def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
@@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
self.add_keys(node_ids)
def include_node_id_in_input(self) -> bool:
return False
def add_keys(self, node_ids):
async def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def get_node_signature(self, dynprompt, node_id):
async def get_node_signature(self, dynprompt, node_id):
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
@@ -150,9 +150,10 @@ class BasicCache:
self.cache = {}
self.subcaches = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
await self.cache_key_set.add_keys(node_ids)
self.is_changed_cache = is_changed_cache
self.initialized = True
@@ -201,13 +202,13 @@ class BasicCache:
else:
return None
def _ensure_subcache(self, node_id, children_ids):
async def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache
def _get_subcache(self, node_id):
@@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
assert cache is not None
cache._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)
return await cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
@@ -273,8 +274,8 @@ class LRUCache(BasicCache):
self.used_generation = {}
self.children = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
super().set_prompt(dynprompt, node_ids, is_changed_cache)
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
self.generation += 1
for node_id in node_ids:
self._mark_used(node_id)
@@ -303,11 +304,11 @@ class LRUCache(BasicCache):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
super()._ensure_subcache(node_id, children_ids)
await super()._ensure_subcache(node_id, children_ids)
self.cache_key_set.add_keys(children_ids)
await self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id)
self.children[cache_key] = []
@@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
@@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache)
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
@@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
"""
return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids):
async def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
@@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
Returns:
The subcache object for the node.
"""
subcache = super()._ensure_subcache(node_id, children_ids)
subcache = await super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Type, Literal
import nodes
import asyncio
from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
@@ -100,6 +101,8 @@ class TopologicalSort:
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
self.externalBlocks = 0
self.unblockedEvent = asyncio.Event()
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -153,6 +156,16 @@ class TopologicalSort:
for link in links:
self.add_strong_link(*link)
def add_external_block(self, node_id):
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
self.externalBlocks += 1
self.blockCount[node_id] += 1
def unblock():
self.externalBlocks -= 1
self.blockCount[node_id] -= 1
self.unblockedEvent.set()
return unblock
def is_cached(self, node_id):
return False
@@ -181,11 +194,16 @@ class ExecutionList(TopologicalSort):
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def stage_node_execution(self):
async def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None, None, None
available = self.get_ready_nodes()
while len(available) == 0 and self.externalBlocks > 0:
# Wait for an external block to be released
await self.unblockedEvent.wait()
self.unblockedEvent.clear()
available = self.get_ready_nodes()
if len(available) == 0:
cycled_nodes = self.get_nodes_in_cycle()
# Because cycles composed entirely of static nodes are caught during initial validation,

288
comfy_execution/progress.py Normal file
View File

@@ -0,0 +1,288 @@
from typing import TypedDict, Dict, Optional
from typing_extensions import override
from PIL import Image
from enum import Enum
from abc import ABC
from tqdm import tqdm
from comfy_execution.graph import DynamicPrompt
from protocol import BinaryEventTypes
class NodeState(Enum):
Pending = "pending"
Running = "running"
Finished = "finished"
Error = "error"
class NodeProgressState(TypedDict):
"""
A class to represent the state of a node's progress.
"""
state: NodeState
value: float
max: float
class ProgressHandler(ABC):
"""
Abstract base class for progress handlers.
Progress handlers receive progress updates and display them in various ways.
"""
def __init__(self, name: str):
self.name = name
self.enabled = True
def set_registry(self, registry: "ProgressRegistry"):
pass
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
"""Called when a node starts processing"""
pass
def update_handler(self, node_id: str, value: float, max_value: float,
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
"""Called when a node's progress is updated"""
pass
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
"""Called when a node finishes processing"""
pass
def reset(self):
"""Called when the progress registry is reset"""
pass
def enable(self):
"""Enable this handler"""
self.enabled = True
def disable(self):
"""Disable this handler"""
self.enabled = False
class CLIProgressHandler(ProgressHandler):
"""
Handler that displays progress using tqdm progress bars in the CLI.
"""
def __init__(self):
super().__init__("cli")
self.progress_bars: Dict[str, tqdm] = {}
@override
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Create a new tqdm progress bar
if node_id not in self.progress_bars:
self.progress_bars[node_id] = tqdm(
total=state["max"],
desc=f"Node {node_id}",
unit="steps",
leave=True,
position=len(self.progress_bars)
)
@override
def update_handler(self, node_id: str, value: float, max_value: float,
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
# Handle case where start_handler wasn't called
if node_id not in self.progress_bars:
self.progress_bars[node_id] = tqdm(
total=max_value,
desc=f"Node {node_id}",
unit="steps",
leave=True,
position=len(self.progress_bars)
)
self.progress_bars[node_id].update(value)
else:
# Update existing progress bar
if max_value != self.progress_bars[node_id].total:
self.progress_bars[node_id].total = max_value
# Calculate the update amount (difference from current position)
current_position = self.progress_bars[node_id].n
update_amount = value - current_position
if update_amount > 0:
self.progress_bars[node_id].update(update_amount)
@override
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Complete and close the progress bar if it exists
if node_id in self.progress_bars:
# Ensure the bar shows 100% completion
remaining = state["max"] - self.progress_bars[node_id].n
if remaining > 0:
self.progress_bars[node_id].update(remaining)
self.progress_bars[node_id].close()
del self.progress_bars[node_id]
@override
def reset(self):
# Close all progress bars
for bar in self.progress_bars.values():
bar.close()
self.progress_bars.clear()
class WebUIProgressHandler(ProgressHandler):
"""
Handler that sends progress updates to the WebUI via WebSockets.
"""
def __init__(self, server_instance):
super().__init__("webui")
self.server_instance = server_instance
def set_registry(self, registry: "ProgressRegistry"):
self.registry = registry
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
"""Send the current progress state to the client"""
if self.server_instance is None:
return
# Only send info for non-pending nodes
active_nodes = {
node_id: {
"value": state["value"],
"max": state["max"],
"state": state["state"].value,
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
}
for node_id, state in nodes.items()
if state["state"] != NodeState.Pending
}
# Send a combined progress_state message with all node states
self.server_instance.send_sync("progress_state", {
"prompt_id": prompt_id,
"nodes": active_nodes
})
@override
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
@override
def update_handler(self, node_id: str, value: float, max_value: float,
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
if image:
metadata = {
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
}
self.server_instance.send_sync(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), self.server_instance.client_id)
@override
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
class ProgressRegistry:
"""
Registry that maintains node progress state and notifies registered handlers.
"""
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.nodes: Dict[str, NodeProgressState] = {}
self.handlers: Dict[str, ProgressHandler] = {}
def register_handler(self, handler: ProgressHandler) -> None:
"""Register a progress handler"""
self.handlers[handler.name] = handler
def unregister_handler(self, handler_name: str) -> None:
"""Unregister a progress handler"""
if handler_name in self.handlers:
# Allow handler to clean up resources
self.handlers[handler_name].reset()
del self.handlers[handler_name]
def enable_handler(self, handler_name: str) -> None:
"""Enable a progress handler"""
if handler_name in self.handlers:
self.handlers[handler_name].enable()
def disable_handler(self, handler_name: str) -> None:
"""Disable a progress handler"""
if handler_name in self.handlers:
self.handlers[handler_name].disable()
def ensure_entry(self, node_id: str) -> NodeProgressState:
"""Ensure a node entry exists"""
if node_id not in self.nodes:
self.nodes[node_id] = NodeProgressState(
state = NodeState.Pending,
value = 0,
max = 1
)
return self.nodes[node_id]
def start_progress(self, node_id: str) -> None:
"""Start progress tracking for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Running
entry["value"] = 0.0
entry["max"] = 1.0
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.start_handler(node_id, entry, self.prompt_id)
def update_progress(self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]) -> None:
"""Update progress for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Running
entry["value"] = value
entry["max"] = max_value
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)
def finish_progress(self, node_id: str) -> None:
"""Finish progress tracking for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Finished
entry["value"] = entry["max"]
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.finish_handler(node_id, entry, self.prompt_id)
def reset_handlers(self) -> None:
"""Reset all handlers"""
for handler in self.handlers.values():
handler.reset()
# Global registry instance
global_progress_registry: ProgressRegistry = ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
global global_progress_registry
# Reset existing handlers if registry exists
if global_progress_registry is not None:
global_progress_registry.reset_handlers()
# Create new registry
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
def add_progress_handler(handler: ProgressHandler) -> None:
handler.set_registry(global_progress_registry)
global_progress_registry.register_handler(handler)
def get_progress_state() -> ProgressRegistry:
return global_progress_registry

46
comfy_execution/utils.py Normal file
View File

@@ -0,0 +1,46 @@
import contextvars
from typing import Optional, NamedTuple
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
Attributes:
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
"""
prompt_id: str
node_id: str
list_index: Optional[int]
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> Optional[ExecutionContext]:
return current_executing_context.get(None)
class CurrentNodeContext:
"""
Context manager for setting the current executing node context.
Sets the current_executing_context on enter and resets it on exit.
Example:
with CurrentNodeContext(node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
self.context = ExecutionContext(
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
)
self.token = None
def __enter__(self):
self.token = current_executing_context.set(self.context)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.token is not None:
current_executing_context.reset(self.token)

View File

@@ -133,6 +133,14 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create in-memory WAV buffer
wav_buffer = io.BytesIO()
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
wav_buffer.seek(0) # Rewind for reading
# Use PyAV to convert and add metadata
input_container = av.open(wav_buffer)
# Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
@@ -142,6 +150,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
output_container.metadata[key] = value
# Set up the output stream with appropriate properties
input_container.streams.audio[0]
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
@@ -166,16 +175,18 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Copy frames from input to output
for frame in input_container.decode(audio=0):
frame.pts = None # Let PyAV handle timestamps
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
input_container.close()
# Write the output to file
output_buffer.seek(0)

View File

@@ -583,49 +583,6 @@ class GetImageSize:
return width, height, batch_size
class ImageRotate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": (IO.IMAGE,),
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
}}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "rotate"
CATEGORY = "image/transform"
def rotate(self, image, rotation):
rotate_by = 0
if rotation.startswith("90"):
rotate_by = 1
elif rotation.startswith("180"):
rotate_by = 2
elif rotation.startswith("270"):
rotate_by = 3
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
return (image,)
class ImageFlip:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": (IO.IMAGE,),
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
}}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "flip"
CATEGORY = "image/transform"
def flip(self, image, flip_method):
if flip_method.startswith("x"):
image = torch.flip(image, dims=[1])
elif flip_method.startswith("y"):
image = torch.flip(image, dims=[2])
return (image,)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
@@ -637,6 +594,4 @@ NODE_CLASS_MAPPINGS = {
"ImageStitch": ImageStitch,
"ResizeAndPadImage": ResizeAndPadImage,
"GetImageSize": GetImageSize,
"ImageRotate": ImageRotate,
"ImageFlip": ImageFlip,
}

View File

@@ -5,8 +5,6 @@ import os
from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile
from pathlib import Path
def normalize_path(path):
return path.replace('\\', '/')
@@ -18,14 +16,7 @@ class Load3D():
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
]
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
@@ -70,14 +61,7 @@ class Load3DAnimation():
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
]
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.fbx'))]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),

View File

@@ -134,8 +134,8 @@ class LTXVAddGuide:
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1 and frame_idx != 0:
frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1 # frame index - 1 must be divisible by 8 or frame_idx == 0
if guide_length > 1:
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
@@ -144,7 +144,7 @@ class LTXVAddGuide:
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None:
keyframe_idxs = pixel_coords

View File

@@ -152,7 +152,7 @@ class ImageColorToMask:
def image_to_mask(self, image, color):
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
mask = torch.where(temp == color, 1.0, 0).float()
mask = torch.where(temp == color, 255, 0).float()
return (mask,)
class SolidMask:

View File

@@ -78,75 +78,7 @@ class SkipLayerGuidanceDiT:
return (m, )
class SkipLayerGuidanceDiTSimple:
'''
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
EXPERIMENTAL = True
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
double_layers = re.findall(r'\d+', double_layers)
double_layers = [int(i) for i in double_layers]
single_layers = re.findall(r'\d+', single_layers)
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return (model, )
def calc_cond_batch_function(args):
x = args["input"]
model = args["model"]
conds = args["conds"]
sigma = args["sigma"]
model_options = args["model_options"]
slg_model_options = model_options.copy()
for layer in double_layers:
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "double_block", layer)
for layer in single_layers:
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "single_block", layer)
cond, uncond = conds
sigma_ = sigma[0].item()
if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
_, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
out = [cond_out, uncond_out]
else:
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
return out
m = model.clone()
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.44"
__version__ = "0.3.43"

View File

@@ -8,12 +8,14 @@ import time
import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional
import asyncio
import torch
import comfy.model_management
import nodes
from comfy_execution.caching import (
BasicCache,
CacheKeySetID,
CacheKeySetInputSignature,
DependencyAwareCache,
@@ -28,6 +30,8 @@ from comfy_execution.graph import (
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
class ExecutionResult(Enum):
@@ -39,12 +43,13 @@ class DuplicateNodeError(Exception):
pass
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.outputs_cache = outputs_cache
self.is_changed = {}
def get(self, node_id):
async def get(self, node_id):
if node_id in self.is_changed:
return self.is_changed[node_id]
@@ -62,7 +67,8 @@ class IsChangedCache:
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
logging.warning("WARNING: {}".format(e))
@@ -164,7 +170,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
async def resolve_map_node_over_list_results(results):
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
if len(remaining) == 0:
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
else:
done, pending = await asyncio.wait(remaining)
for task in done:
exc = task.exception()
if exc is not None:
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -178,7 +196,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
def process_inputs(inputs, index=None, input_is_list=False):
async def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt:
nodes.before_node_execution()
execution_block = None
@@ -194,20 +212,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if execution_block is None:
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs))
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
return await f(**args)
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
if task.done():
result = task.result()
results.append(result)
else:
results.append(task)
else:
with CurrentNodeContext(prompt_id, unique_id, index):
result = f(**inputs)
results.append(result)
else:
results.append(execution_block)
if input_is_list:
process_inputs(input_data_all, 0, input_is_list=input_is_list)
await process_inputs(input_data_all, 0, input_is_list=input_is_list)
elif max_len_input == 0:
process_inputs({})
await process_inputs({})
else:
for i in range(max_len_input):
input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i)
await process_inputs(input_dict, i)
return results
def merge_result_data(results, obj):
# check which outputs need concatenating
output = []
@@ -229,11 +264,18 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
return output, ui, has_subgraph, False
def get_output_from_returns(return_values, obj):
results = []
uis = []
subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
@@ -267,6 +309,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
else:
output = []
ui = dict()
# TODO: Think there's an existing bug here
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui, has_subgraph
@@ -279,7 +325,7 @@ def format_value(x):
else:
return str(x)
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -291,11 +337,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
get_progress_state().finish_progress(unique_id)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
try:
if unique_id in pending_subgraph_results:
if unique_id in pending_async_nodes:
results = []
for r in pending_async_nodes[unique_id]:
if isinstance(r, asyncio.Task):
try:
results.append(r.result())
except Exception as ex:
# An async task failed - propagate the exception up
del pending_async_nodes[unique_id]
raise ex
else:
results.append(r)
del pending_async_nodes[unique_id]
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
elif unique_id in pending_subgraph_results:
cached_results = pending_subgraph_results[unique_id]
resolved_outputs = []
for is_subgraph, result in cached_results:
@@ -317,6 +378,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
output_ui = []
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
@@ -328,7 +390,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys
@@ -357,8 +420,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
else:
return block
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
async def await_completion():
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True)
unblock()
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
@@ -401,7 +474,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
subcache.clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
for link in new_output_links:
@@ -446,6 +520,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.FAILURE, error_details, ex)
get_progress_state().finish_progress(unique_id)
executed.add(unique_id)
return (ExecutionResult.SUCCESS, None, None)
@@ -500,6 +575,11 @@ class PromptExecutor:
self.add_message("execution_error", mes, broadcast=False)
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
if "client_id" in extra_data:
@@ -512,9 +592,11 @@ class PromptExecutor:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
cached_nodes = []
@@ -527,6 +609,7 @@ class PromptExecutor:
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@@ -534,12 +617,13 @@ class PromptExecutor:
execution_list.add_node(node_id)
while not execution_list.is_empty():
node_id, error, ex = execution_list.stage_node_execution()
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -569,7 +653,7 @@ class PromptExecutor:
comfy.model_management.unload_all_models()
def validate_inputs(prompt, item, validated):
async def validate_inputs(prompt_id, prompt, item, validated):
unique_id = item
if unique_id in validated:
return validated[unique_id]
@@ -646,7 +730,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
r = await validate_inputs(prompt_id, prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
@@ -771,7 +855,8 @@ def validate_inputs(prompt, item, validated):
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True and not isinstance(r, ExecutionBlocker):
@@ -804,7 +889,7 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
def validate_prompt(prompt):
async def validate_prompt(prompt_id, prompt):
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@@ -847,7 +932,7 @@ def validate_prompt(prompt):
valid = False
reasons = []
try:
m = validate_inputs(prompt, o, validated)
m = await validate_inputs(prompt_id, prompt, o, validated)
valid = m[0]
reasons = m[1]
except Exception as ex:

21
main.py
View File

@@ -11,6 +11,8 @@ import itertools
import utils.extra_config
import logging
import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
@@ -131,7 +133,7 @@ import comfy.utils
import execution
import server
from server import BinaryEventTypes
from protocol import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
@@ -227,14 +229,25 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
)
def hijack_progress(server_instance):
def hook(value, total, preview_image):
def hook(value, total, preview_image, prompt_id=None, node_id=None):
executing_context = get_executing_context()
if prompt_id is None and executing_context is not None:
prompt_id = executing_context.prompt_id
if node_id is None and executing_context is not None:
node_id = executing_context.node_id
comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
if prompt_id is None:
prompt_id = server_instance.last_prompt_id
if node_id is None:
node_id = server_instance.last_node_id
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
get_progress_state().update_progress(node_id, value, total, preview_image)
server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None:
# Also send old method for backward compatibility
# TODO - Remove after this repo is updated to frontend with metadata support
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
comfy.utils.set_progress_bar_global_hook(hook)

View File

@@ -2310,7 +2310,6 @@ def init_builtin_api_nodes():
"nodes_pika.py",
"nodes_runway.py",
"nodes_tripo.py",
"nodes_moonvalley.py",
"nodes_rodin.py",
"nodes_gemini.py",
]

7
protocol.py Normal file
View File

@@ -0,0 +1,7 @@
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
PREVIEW_IMAGE_WITH_METADATA = 4

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.44"
version = "0.3.43"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,6 +1,6 @@
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.35
comfyui-embedded-docs==0.2.4
comfyui-workflow-templates==0.1.31
comfyui-embedded-docs==0.2.3
torch
torchsde
torchvision

View File

@@ -35,11 +35,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
from protocol import BinaryEventTypes
async def send_socket_catch_exception(function, message):
try:
@@ -643,7 +639,8 @@ class PromptServer():
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = execution.validate_prompt(prompt)
prompt_id = str(uuid.uuid4())
valid = await execution.validate_prompt(prompt_id, prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
@@ -651,7 +648,6 @@ class PromptServer():
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
@@ -766,6 +762,10 @@ class PromptServer():
async def send(self, event, data, sid=None):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
# data is (preview_image, metadata)
preview_image, metadata = data
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
else:
@@ -804,6 +804,43 @@ class PromptServer():
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.Resampling.LANCZOS
image = ImageOps.contain(image, (max_size, max_size), resampling)
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
# Prepare metadata
if metadata is None:
metadata = {}
metadata["image_type"] = mimetype
# Serialize metadata as JSON
import json
metadata_json = json.dumps(metadata).encode('utf-8')
metadata_length = len(metadata_json)
# Prepare image data
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
image_bytes = bytesIO.getvalue()
# Combine metadata and image
combined_data = bytearray()
combined_data.extend(struct.pack(">I", metadata_length))
combined_data.extend(metadata_json)
combined_data.extend(image_bytes)
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data)

View File

@@ -1,3 +1,4 @@
pytest>=7.8.0
pytest-aiohttp
pytest-asyncio
websocket-client

View File

@@ -1,4 +1,4 @@
# Config for testing nodes
testing:
custom_nodes: tests/inference/testing_nodes
custom_nodes: testing_nodes

View File

@@ -0,0 +1,410 @@
import pytest
import time
import torch
import urllib.error
import numpy as np
import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient
@pytest.mark.execution
class TestAsyncNodes:
@fixture(scope="class", autouse=True, params=[
(False, 0),
(True, 0),
(True, 100),
])
def _server(self, args_pytest, request):
pargs = [
'python','main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
]
use_lru, lru_size = request.param
if use_lru:
pargs += ['--cache-lru', str(lru_size)]
# Running server with args: pargs
p = subprocess.Popen(pargs)
yield
p.kill()
torch.cuda.empty_cache()
@fixture(scope="class", autouse=True)
def shared_client(self, args_pytest, _server):
client = ComfyClient()
n_tries = 5
for i in range(n_tries):
time.sleep(4)
try:
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
except ConnectionRefusedError:
# Retrying...
pass
else:
break
yield client
del client
torch.cuda.empty_cache()
@fixture
def client(self, shared_client, request):
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
yield shared_client
@fixture
def builder(self, request):
yield GraphBuilder(prefix=request.node.name)
# Happy Path Tests
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that a basic async node executes correctly."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
output = g.node("SaveImage", images=sleep_node.out(0))
result = client.run(g)
# Verify execution completed
assert result.did_run(sleep_node), "Async sleep node should have executed"
assert result.did_run(output), "Output node should have executed"
# Verify the image passed through correctly
result_images = result.get_images(output)
assert len(result_images) == 1, "Should have 1 image"
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that multiple async nodes execute in parallel."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create multiple async sleep nodes with different durations
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)
# Add outputs for each
_output1 = g.node("PreviewImage", images=sleep1.out(0))
_output2 = g.node("PreviewImage", images=sleep2.out(0))
_output3 = g.node("PreviewImage", images=sleep3.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
# Verify all nodes executed
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with proper dependency handling."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Chain of async operations
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
# Average depends on both async results
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
output = g.node("SaveImage", images=average.out(0))
result = client.run(g)
# Verify execution order
assert result.did_run(sleep1) and result.did_run(sleep2)
assert result.did_run(average) and result.did_run(output)
# Verify averaged result
result_images = result.get_images(output)
avg_value = np.array(result_images[0]).mean()
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
"""Test async VALIDATE_INPUTS function."""
g = builder
# Create a test node with async validation
validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
g.node("SaveImage", images=validation_node.out(0))
# Should pass validation
result = client.run(g)
assert result.did_run(validation_node)
# Test validation failure
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold
with pytest.raises(urllib.error.HTTPError):
client.run(g)
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with lazy evaluation."""
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
# Create async nodes that will be evaluated lazily
sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)
# Use lazy mix that only needs sleep1 (mask=0.0)
lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
g.node("SaveImage", images=lazy_mix.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Should only execute sleep1, not sleep2
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
assert result.did_run(sleep1), "Sleep1 should have executed"
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
"""Test async check_lazy_status function."""
g = builder
# Create a node with async check_lazy_status
lazy_node = g.node("TestAsyncLazyCheck",
input1="value1",
input2="value2",
condition=True)
g.node("SaveImage", images=lazy_node.out(0))
result = client.run(g)
assert result.did_run(lazy_node)
# Error Handling Tests
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async execution errors are properly handled."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create an async node that will error
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
g.node("SaveImage", images=error_node.out(0))
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
"""Test async validation error handling."""
g = builder
# Node with async validation that will fail
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
g.node("SaveImage", images=validation_node.out(0))
with pytest.raises(urllib.error.HTTPError) as exc_info:
client.run(g)
# Verify it's a validation error
assert exc_info.value.code == 400
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
"""Test handling of async operations that timeout."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Very long sleep that would timeout
timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
g.node("SaveImage", images=timeout_node.out(0))
try:
client.run(g)
assert False, "Should have raised a timeout error"
except Exception as e:
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
"""Test that workflow can recover after async errors."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# First run with error
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
g.node("SaveImage", images=error_node.out(0))
try:
client.run(g)
except Exception:
pass # Expected
# Second run should succeed
g2 = GraphBuilder(prefix="recovery_test")
image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
g2.node("SaveImage", images=sleep_node.out(0))
result = client.run(g2)
assert result.did_run(sleep_node), "Should be able to run after error"
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test handling when sync node errors while async node is executing."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Async node that takes time
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)
# Sync node that will error immediately
error_node = g.node("TestSyncError", value=image.out(0))
# Both feed into output
g.node("PreviewImage", images=sleep_node.out(0))
g.node("PreviewImage", images=error_node.out(0))
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
# Verify the sync error was caught even though async was running
assert 'prompt_id' in e.args[0]
# Edge Cases
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with execution blockers."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Async sleep nodes
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
# Create list of images
image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))
# Create list of blocking conditions - [False, True] to block only the second item
int1 = g.node("StubInt", value=1)
int2 = g.node("StubInt", value=2)
block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))
# Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")
# Block based on the comparison results
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
output = g.node("PreviewImage", images=blocker.out(0))
result = client.run(g)
images = result.get_images(output)
assert len(images) == 1, "Should have blocked second image"
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async nodes are properly cached."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
g.node("SaveImage", images=sleep_node.out(0))
# First run
result1 = client.run(g)
assert result1.did_run(sleep_node), "Should run first time"
# Second run - should be cached
start_time = time.time()
result2 = client.run(g)
elapsed_time = time.time() - start_time
assert not result2.did_run(sleep_node), "Should be cached"
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes within dynamically generated prompts."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Node that generates async nodes dynamically
dynamic_async = g.node("TestDynamicAsyncGeneration",
image1=image1.out(0),
image2=image2.out(0),
num_async_nodes=3,
sleep_duration=0.2)
g.node("SaveImage", images=dynamic_async.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Should execute async nodes in parallel within dynamic prompt
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
assert result.did_run(dynamic_async)
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async resources are properly cleaned up."""
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create multiple async nodes that use resources
resource_nodes = []
for i in range(5):
node = g.node("TestAsyncResourceUser",
value=image.out(0),
resource_id=f"resource_{i}",
duration=0.1)
resource_nodes.append(node)
g.node("PreviewImage", images=node.out(0))
result = client.run(g)
# Verify all nodes executed
for node in resource_nodes:
assert result.did_run(node)
# Run again to ensure resources were cleaned up
result2 = client.run(g)
# Should be cached but not error due to resource conflicts
for node in resource_nodes:
assert not result2.did_run(node), "Should be cached"
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
"""Test cancellation of async operations."""
# This would require implementing cancellation in the client
# For now, we'll test that long-running async operations can be interrupted
pass # TODO: Implement when cancellation API is available
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test workflows with both sync and async nodes."""
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
# Mix of sync and async operations
# Sync: lazy mix images
sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
# Async: sleep
async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
# Sync: custom validation
sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
# Async: sleep again
async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)
output = g.node("SaveImage", images=async_op2.out(0))
result = client.run(g)
# Verify all nodes executed in correct order
assert result.did_run(sync_op1)
assert result.did_run(async_op1)
assert result.did_run(sync_op2)
assert result.did_run(async_op2)
# Image should be a mix of black and white (gray)
result_images = result.get_images(output)
avg_value = np.array(result_images[0]).mean()
assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"

View File

@@ -252,7 +252,7 @@ class TestExecution:
@pytest.mark.parametrize("test_type, test_value", [
("StubInt", 5),
("StubFloat", 5.0)
("StubMask", 5.0)
])
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder
@@ -497,6 +497,69 @@ class TestExecution:
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create sleep nodes for each duration
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
# Add outputs to verify the execution
_output1 = g.node("PreviewImage", images=sleep_node1.out(0))
_output2 = g.node("PreviewImage", images=sleep_node2.out(0))
_output3 = g.node("PreviewImage", images=sleep_node3.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# The test should take around 0.4 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
# We'll allow for up to 0.8s total to account for overhead
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
# Verify that all nodes executed
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Create input images with different values
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Create a TestParallelSleep node that expands into multiple TestSleep nodes
parallel_sleep = g.node("TestParallelSleep",
image1=image1.out(0),
image2=image2.out(0),
image3=image3.out(0),
sleep1=0.4,
sleep2=0.5,
sleep3=0.6)
output = g.node("SaveImage", images=parallel_sleep.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Similar to the previous test, expect parallel execution of the sleep nodes
# which should complete in less than the sum of all sleeps
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
# Verify the parallel sleep node executed
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
# Verify we get an image as output (blend of the three input images)
result_images = result.get_images(output)
assert len(result_images) == 1, "Should have 1 image"
# Average pixel value should be around 170 (255 * 2 // 3)
avg_value = numpy.array(result_images[0]).mean()
assert avg_value == 170, f"Image average value {avg_value} should be 170"
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked.

View File

@@ -3,6 +3,7 @@ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DI
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
@@ -13,6 +14,7 @@ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
@@ -20,4 +22,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)

View File

@@ -0,0 +1,343 @@
import torch
import asyncio
from typing import Dict
from comfy.utils import ProgressBar
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
class TestAsyncValidation(ComfyNodeABC):
"""Test node with async VALIDATE_INPUTS."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 5.0}),
"threshold": ("FLOAT", {"default": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, threshold):
# Simulate async validation (e.g., checking remote service)
await asyncio.sleep(0.05)
if value > threshold:
return f"Value {value} exceeds threshold {threshold}"
return True
def process(self, value, threshold):
# Create image based on value
intensity = value / 10.0
image = torch.ones([1, 512, 512, 3]) * intensity
return (image,)
class TestAsyncError(ComfyNodeABC):
"""Test node that errors during async execution."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "error_execution"
CATEGORY = "_for_testing/async"
async def error_execution(self, value, error_after):
await asyncio.sleep(error_after)
raise RuntimeError("Intentional async execution error for testing")
class TestAsyncValidationError(ComfyNodeABC):
"""Test node with async validation that always fails."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 5.0}),
"max_value": ("FLOAT", {"default": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, max_value):
await asyncio.sleep(0.05)
# Always fail validation for values > max_value
if value > max_value:
return f"Async validation failed: {value} > {max_value}"
return True
def process(self, value, max_value):
# This won't be reached if validation fails
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
return (image,)
class TestAsyncTimeout(ComfyNodeABC):
"""Test node that simulates timeout scenarios."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "timeout_execution"
CATEGORY = "_for_testing/async"
async def timeout_execution(self, value, timeout, operation_time):
try:
# This will timeout if operation_time > timeout
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
return (value,)
except asyncio.TimeoutError:
raise RuntimeError(f"Operation timed out after {timeout} seconds")
class TestSyncError(ComfyNodeABC):
"""Test node that errors synchronously (for mixed sync/async testing)."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sync_error"
CATEGORY = "_for_testing/async"
def sync_error(self, value):
raise RuntimeError("Intentional sync execution error for testing")
class TestAsyncLazyCheck(ComfyNodeABC):
"""Test node with async check_lazy_status."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": (IO.ANY, {"lazy": True}),
"input2": (IO.ANY, {"lazy": True}),
"condition": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
async def check_lazy_status(self, condition, input1, input2):
# Simulate async checking (e.g., querying remote service)
await asyncio.sleep(0.05)
needed = []
if condition and input1 is None:
needed.append("input1")
if not condition and input2 is None:
needed.append("input2")
return needed
def process(self, input1, input2, condition):
# Return a simple image
return (torch.ones([1, 512, 512, 3]),)
class TestDynamicAsyncGeneration(ComfyNodeABC):
"""Test node that dynamically generates async nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_async_workflow"
CATEGORY = "_for_testing/async"
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
g = GraphBuilder()
# Create multiple async sleep nodes
sleep_nodes = []
for i in range(num_async_nodes):
image = image1 if i % 2 == 0 else image2
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
sleep_nodes.append(sleep_node)
# Average all results
if len(sleep_nodes) == 1:
final_node = sleep_nodes[0]
else:
avg_inputs = {"input1": sleep_nodes[0].out(0)}
for i, node in enumerate(sleep_nodes[1:], 2):
avg_inputs[f"input{i}"] = node.out(0)
final_node = g.node("TestVariadicAverage", **avg_inputs)
return {
"result": (final_node.out(0),),
"expand": g.finalize(),
}
class TestAsyncResourceUser(ComfyNodeABC):
"""Test node that uses resources during async execution."""
# Class-level resource tracking for testing
_active_resources: Dict[str, bool] = {}
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"resource_id": ("STRING", {"default": "resource_0"}),
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "use_resource"
CATEGORY = "_for_testing/async"
async def use_resource(self, value, resource_id, duration):
# Check if resource is already in use
if self._active_resources.get(resource_id, False):
raise RuntimeError(f"Resource {resource_id} is already in use!")
# Mark resource as in use
self._active_resources[resource_id] = True
try:
# Simulate resource usage
await asyncio.sleep(duration)
return (value,)
finally:
# Always clean up resource
self._active_resources[resource_id] = False
class TestAsyncBatchProcessing(ComfyNodeABC):
"""Test async processing of batched inputs."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process_batch"
CATEGORY = "_for_testing/async"
async def process_batch(self, images, process_time_per_item, unique_id):
batch_size = images.shape[0]
pbar = ProgressBar(batch_size, node_id=unique_id)
# Process each image in the batch
processed = []
for i in range(batch_size):
# Simulate async processing
await asyncio.sleep(process_time_per_item)
# Simple processing: invert the image
processed_image = 1.0 - images[i:i+1]
processed.append(processed_image)
pbar.update(1)
# Stack processed images
result = torch.cat(processed, dim=0)
return (result,)
class TestAsyncConcurrentLimit(ComfyNodeABC):
"""Test concurrent execution limits for async nodes."""
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
"node_id": ("INT", {"default": 0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "limited_execution"
CATEGORY = "_for_testing/async"
async def limited_execution(self, value, duration, node_id):
async with self._semaphore:
# Node {node_id} acquired semaphore
await asyncio.sleep(duration)
# Node {node_id} releasing semaphore
return (value,)
# Add node mappings
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
"TestAsyncValidation": TestAsyncValidation,
"TestAsyncError": TestAsyncError,
"TestAsyncValidationError": TestAsyncValidationError,
"TestAsyncTimeout": TestAsyncTimeout,
"TestSyncError": TestSyncError,
"TestAsyncLazyCheck": TestAsyncLazyCheck,
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
"TestAsyncResourceUser": TestAsyncResourceUser,
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
}
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestAsyncValidation": "Test Async Validation",
"TestAsyncError": "Test Async Error",
"TestAsyncValidationError": "Test Async Validation Error",
"TestAsyncTimeout": "Test Async Timeout",
"TestSyncError": "Test Sync Error",
"TestAsyncLazyCheck": "Test Async Lazy Check",
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
"TestAsyncResourceUser": "Test Async Resource User",
"TestAsyncBatchProcessing": "Test Async Batch Processing",
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
}

View File

@@ -1,6 +1,11 @@
import torch
import time
import asyncio
from comfy.utils import ProgressBar
from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
class TestLazyMixImages:
@classmethod
@@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
"expand": g.finalize(),
}
class TestSamplingInExpansion:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sampling_in_expansion"
CATEGORY = "Testing/Nodes"
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
g = GraphBuilder()
# Create a basic image generation workflow using the input model, clip and vae
# 1. Setup text prompts using the provided CLIP model
positive_prompt = g.node("CLIPTextEncode",
text=prompt,
clip=clip)
negative_prompt = g.node("CLIPTextEncode",
text=negative_prompt,
clip=clip)
# 2. Create empty latent with specified size
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
# 3. Setup sampler and generate image latent
sampler = g.node("KSampler",
model=model,
positive=positive_prompt.out(0),
negative=negative_prompt.out(0),
latent_image=empty_latent.out(0),
seed=seed,
steps=steps,
cfg=cfg,
sampler_name="euler_ancestral",
scheduler="normal")
# 4. Decode latent to image using VAE
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
return {
"result": (output.out(0),),
"expand": g.finalize(),
}
class TestSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep"
CATEGORY = "_for_testing"
async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id)
start = time.time()
expiration = start + seconds
now = start
while now < expiration:
now = time.time()
pbar.update_absolute(now - start)
await asyncio.sleep(0.01)
return (value,)
class TestParallelSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE", ),
"image2": ("IMAGE", ),
"image3": ("IMAGE", ),
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "parallel_sleep"
CATEGORY = "_for_testing"
OUTPUT_NODE = True
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
# Create a graph dynamically with three TestSleep nodes
g = GraphBuilder()
# Create sleep nodes for each duration and image
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
# Blend the results using TestVariadicAverage
blend = g.node("TestVariadicAverage",
input1=sleep_node1.out(0),
input2=sleep_node2.out(0),
input3=sleep_node3.out(0))
return {
"result": (blend.out(0),),
"expand": g.finalize(),
}
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestCustomValidation5": TestCustomValidation5,
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
"TestMixedExpansionReturns": TestMixedExpansionReturns,
"TestSamplingInExpansion": TestSamplingInExpansion,
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestCustomValidation5": "Custom Validation 5",
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
"TestMixedExpansionReturns": "Mixed Expansion Returns",
"TestSamplingInExpansion": "Sampling In Expansion",
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
}