mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 05:00:03 +00:00
Compare commits
5 Commits
v0.3.44
...
js/drafts/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0254d9cc11 | ||
|
|
92f9a10782 | ||
|
|
a6a6b615f4 | ||
|
|
50bf72f852 | ||
|
|
46c8311d14 |
4
.github/workflows/test-unit.yml
vendored
4
.github/workflows/test-unit.yml
vendored
@@ -28,3 +28,7 @@ jobs:
|
||||
run: |
|
||||
pip install -r tests-unit/requirements.txt
|
||||
python -m pytest tests-unit
|
||||
- name: Run Execution Model Tests
|
||||
run: |
|
||||
python -m pytest tests/inference/test_execution.py
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "5"
|
||||
default: "2"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -53,8 +53,6 @@ jobs:
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
|
||||
@@ -86,7 +86,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
- Works even if you don't have a GPU with: ```--cpu``` (slow)
|
||||
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||
- Safe loading of ckpt, pt, pth, etc.. files.
|
||||
- Embeddings/Textual inversion
|
||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||
@@ -102,6 +101,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
@@ -243,7 +243,7 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
|
||||
@@ -1,10 +1,55 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention, FeedForward
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
import comfy.ops
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * torch.nn.functional.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
ops.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
ops.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class GatedCrossAttentionDense(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
|
||||
@@ -412,13 +412,9 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
||||
|
||||
@@ -1071,9 +1067,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
|
||||
@@ -1091,7 +1085,6 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
@@ -1115,9 +1108,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if t_next == 0: # Denoising step
|
||||
x_next = denoised
|
||||
elif order == 1: # First Euler step.
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
h_n = (t_next - t_cur)
|
||||
@@ -1157,7 +1148,6 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
@torch.no_grad()
|
||||
@@ -1208,7 +1198,6 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@@ -1415,7 +1404,6 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
|
||||
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
|
||||
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||
@@ -1442,19 +1430,19 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
if i == 0:
|
||||
# Euler method
|
||||
if cfg_pp:
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
else:
|
||||
x = x + d * dt
|
||||
|
||||
if i >= 1:
|
||||
# Gradient estimation
|
||||
else:
|
||||
# Gradient estimation
|
||||
if cfg_pp:
|
||||
d_bar = (ge_gamma - 1) * (d - old_d)
|
||||
x = denoised + d * sigmas[i + 1] + d_bar * dt
|
||||
else:
|
||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
@@ -379,9 +379,6 @@ class ModelPatcher:
|
||||
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
|
||||
|
||||
def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
|
||||
self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
|
||||
@@ -336,12 +336,9 @@ class fp8_ops(manual_cast):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
@@ -373,11 +373,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
uncond_ = uncond
|
||||
|
||||
conds = [cond, uncond_]
|
||||
if "sampler_calc_cond_batch_function" in model_options:
|
||||
args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||
out = model_options["sampler_calc_cond_batch_function"](args)
|
||||
else:
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||
|
||||
@@ -77,7 +77,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
@@ -998,11 +997,12 @@ def set_progress_bar_global_hook(function):
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total):
|
||||
def __init__(self, total, node_id=None):
|
||||
global PROGRESS_BAR_HOOK
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
self.node_id = node_id
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
@@ -1011,7 +1011,7 @@ class ProgressBar:
|
||||
value = self.total
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total, preview)
|
||||
self.hook(self.current, self.total, preview, node_id=self.node_id)
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
import io
|
||||
from typing import Optional
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
@@ -32,22 +31,6 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
loading the entire video into memory.
|
||||
|
||||
Returns:
|
||||
Either a file path (str) or a BytesIO object that can be opened with av.
|
||||
|
||||
Default implementation creates a BytesIO buffer, but subclasses should
|
||||
override this for better performance when possible.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
self.save_to(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
|
||||
@@ -64,15 +64,6 @@ class VideoFromFile(VideoInput):
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-07-06T09:47:31+00:00
|
||||
# timestamp: 2025-05-19T21:38:55+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1355,158 +1355,6 @@ class ModelResponseProperties(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class Keyframes(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyPromptResponse(BaseModel):
|
||||
error: Optional[Dict[str, Any]] = None
|
||||
frame_conditioning: Optional[Dict[str, Any]] = None
|
||||
id: Optional[str] = None
|
||||
inference_params: Optional[Dict[str, Any]] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
model_params: Optional[Dict[str, Any]] = None
|
||||
output_url: Optional[str] = None
|
||||
prompt_text: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
None, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
fps: Optional[int] = Field(
|
||||
24, description='Frames per second of the generated video'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
12.5, description='Guidance scale for generation control'
|
||||
)
|
||||
height: Optional[int] = Field(
|
||||
1080, description='Height of the generated video in pixels'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
None, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
width: Optional[int] = Field(
|
||||
1920, description='Width of the generated video in pixels'
|
||||
)
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoRequest(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
|
||||
prompt_text: Optional[str] = None
|
||||
webhook_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileRequest(BaseModel):
|
||||
file: Optional[StrictBytes] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileResponse(BaseModel):
|
||||
access_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
None, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
12.5, description='Guidance scale for generation control'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
None, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
|
||||
|
||||
class ControlType(str, Enum):
|
||||
motion_control = 'motion_control'
|
||||
pose_control = 'pose_control'
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoRequest(BaseModel):
|
||||
control_type: ControlType = Field(
|
||||
..., description='Supported types for video control'
|
||||
)
|
||||
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
|
||||
prompt_text: str = Field(..., description='Describes the video to generate')
|
||||
video_url: str = Field(..., description='Url to control video')
|
||||
webhook_url: Optional[str] = Field(
|
||||
None, description='Optional webhook URL for notifications'
|
||||
)
|
||||
|
||||
|
||||
class Moderation(str, Enum):
|
||||
low = 'low'
|
||||
auto = 'auto'
|
||||
@@ -3259,23 +3107,6 @@ class LumaUpscaleVideoGenerationRequest(BaseModel):
|
||||
resolution: Optional[LumaVideoModelOutputResolution] = None
|
||||
|
||||
|
||||
class MoonvalleyImageToVideoRequest(MoonvalleyTextToVideoRequest):
|
||||
keyframes: Optional[Dict[str, Keyframes]] = None
|
||||
|
||||
|
||||
class MoonvalleyResizeVideoRequest(MoonvalleyVideoToVideoRequest):
|
||||
frame_position: Optional[List[int]] = Field(None, max_length=2, min_length=2)
|
||||
frame_resolution: Optional[List[int]] = Field(None, max_length=2, min_length=2)
|
||||
scale: Optional[List[int]] = Field(None, max_length=2, min_length=2)
|
||||
|
||||
|
||||
class MoonvalleyTextToImageRequest(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
|
||||
prompt_text: Optional[str] = None
|
||||
webhook_url: Optional[str] = None
|
||||
|
||||
|
||||
class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]):
|
||||
root: Union[OutputTextContent, OutputAudioContent]
|
||||
|
||||
|
||||
@@ -1,639 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
import random
|
||||
import torch
|
||||
from comfy_api_nodes.util.validation_utils import get_image_dimensions, validate_image_dimensions, validate_video_dimensions
|
||||
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
MoonvalleyPromptResponse
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
import av
|
||||
import io
|
||||
|
||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
|
||||
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
|
||||
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
|
||||
|
||||
MIN_WIDTH = 300
|
||||
MIN_HEIGHT = 300
|
||||
|
||||
MAX_WIDTH = 10000
|
||||
MAX_HEIGHT = 10000
|
||||
|
||||
MIN_VID_WIDTH = 300
|
||||
MIN_VID_HEIGHT = 300
|
||||
|
||||
MAX_VID_WIDTH = 10000
|
||||
MAX_VID_HEIGHT = 10000
|
||||
|
||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||
|
||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||
R = TypeVar("R")
|
||||
class MoonvalleyApiError(Exception):
|
||||
"""Base exception for Moonvalley API errors."""
|
||||
pass
|
||||
|
||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||
"""Verifies that the initial response contains a task ID."""
|
||||
return bool(response.id)
|
||||
|
||||
def validate_task_creation_response(response) -> None:
|
||||
if not is_valid_task_creation_response(response):
|
||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||
logging.error(error_msg)
|
||||
raise MoonvalleyApiError(error_msg)
|
||||
|
||||
def get_video_from_response(response):
|
||||
video = response.output_url
|
||||
logging.info(
|
||||
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
|
||||
)
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
"""Returns the first video url from the Moonvalley video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
if response:
|
||||
return str(get_video_from_response(response))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
"completed",
|
||||
],
|
||||
max_poll_attempts=240, # 64 minutes with 16s interval
|
||||
poll_interval=16.0,
|
||||
failed_statuses=["error"],
|
||||
status_extractor=lambda response: (
|
||||
response.status
|
||||
if response and response.status
|
||||
else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
def validate_prompts(prompt:str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH):
|
||||
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
|
||||
if not prompt:
|
||||
raise ValueError("Positive prompt is empty")
|
||||
if len(prompt) > max_length:
|
||||
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
|
||||
if negative_prompt and len(negative_prompt) > max_length:
|
||||
raise ValueError(
|
||||
f"Negative prompt is too long: {len(negative_prompt)} characters"
|
||||
)
|
||||
return True
|
||||
|
||||
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
|
||||
# inference validation
|
||||
# T = num_frames
|
||||
# in all cases, the following must be true: T divisible by 16 and H,W by 8. in addition...
|
||||
# with image conditioning: H*W must be divisible by 8192
|
||||
# without image conditioning: T divisible by 32
|
||||
if num_frames_in and not num_frames_in % 16 == 0 :
|
||||
return False, (
|
||||
"The input video total frame count must be divisible by 16!"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
return False, (
|
||||
f"Height ({height}) and width ({width}) must be " "divisible by 8"
|
||||
)
|
||||
|
||||
if with_frame_conditioning:
|
||||
if (height * width) % 8192 != 0:
|
||||
return False, (
|
||||
f"Height * width ({height * width}) must be "
|
||||
"divisible by 8192 for frame conditioning"
|
||||
)
|
||||
else:
|
||||
if num_frames_in and not num_frames_in % 32 == 0 :
|
||||
return False, (
|
||||
"The input video total frame count must be divisible by 32!"
|
||||
)
|
||||
|
||||
|
||||
def validate_input_image(image: torch.Tensor, with_frame_conditioning: bool=False) -> None:
|
||||
"""
|
||||
Validates the input image adheres to the expectations of the API:
|
||||
- The image resolution should not be less than 300*300px
|
||||
- The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
|
||||
|
||||
"""
|
||||
height, width = get_image_dimensions(image)
|
||||
validate_input_media(width, height, with_frame_conditioning )
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
||||
|
||||
def validate_input_video(video: VideoInput, num_frames_out: int, with_frame_conditioning: bool=False):
|
||||
try:
|
||||
width, height = video.get_dimensions()
|
||||
except Exception as e:
|
||||
logging.error("Error getting dimensions of video: %s", e)
|
||||
raise ValueError(f"Cannot get video dimensions: {e}") from e
|
||||
|
||||
validate_input_media(width, height, with_frame_conditioning)
|
||||
validate_video_dimensions(video, min_width=MIN_VID_WIDTH, min_height=MIN_VID_HEIGHT, max_width=MAX_VID_WIDTH, max_height=MAX_VID_HEIGHT)
|
||||
|
||||
trimmed_video = validate_input_video_length(video, num_frames_out)
|
||||
return trimmed_video
|
||||
|
||||
|
||||
def validate_input_video_length(video: VideoInput, num_frames: int):
|
||||
|
||||
if video.get_duration() > 60:
|
||||
raise MoonvalleyApiError("Input Video lenth should be less than 1min. Please trim.")
|
||||
|
||||
if num_frames == 128:
|
||||
if video.get_duration() < 5:
|
||||
raise MoonvalleyApiError("Input Video length is less than 5s. Please use a video longer than or equal to 5s.")
|
||||
if video.get_duration() > 5:
|
||||
# trim video to 5s
|
||||
video = trim_video(video, 5)
|
||||
if num_frames == 256:
|
||||
if video.get_duration() < 10:
|
||||
raise MoonvalleyApiError("Input Video length is less than 10s. Please use a video longer than or equal to 10s.")
|
||||
if video.get_duration() > 10:
|
||||
# trim video to 10s
|
||||
video = trim_video(video, 10)
|
||||
return video
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode='r')
|
||||
output_container = av.open(output_buffer, mode='w', format='mp4')
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream('h264', rate=stream.average_rate)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
logging.info(f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps")
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream('aac', rate=stream.sample_rate)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info(f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels")
|
||||
|
||||
# Calculate target frame count that's divisible by 32
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (estimated_frames // 32) * 32 # Round down to nearest multiple of 32
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 32 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(f"Encoded {frame_count} video frames (target: {target_frames})")
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info(f"Encoded {audio_frame_count} audio frames")
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
# --- BaseMoonvalleyVideoNode ---
|
||||
class BaseMoonvalleyVideoNode:
|
||||
def parseWidthHeightFromRes(self, resolution: str):
|
||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||
res_map = {
|
||||
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
|
||||
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
|
||||
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
|
||||
"4:3 (1440 x 1080)": {"width": 1440, "height": 1080},
|
||||
"3:4 (1080 x 1440)": {"width": 1080, "height": 1440},
|
||||
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
||||
}
|
||||
if resolution in res_map:
|
||||
return res_map[resolution]
|
||||
else:
|
||||
# Default to 1920x1080 if unknown
|
||||
return {"width": 1920, "height": 1080}
|
||||
|
||||
def parseControlParameter(self, value):
|
||||
control_map = {
|
||||
"Motion Transfer": "motion_control",
|
||||
"Canny": "canny_control",
|
||||
"Pose Transfer": "pose_control",
|
||||
"Depth": "depth_control"
|
||||
}
|
||||
if value in control_map:
|
||||
return control_map[value]
|
||||
else:
|
||||
return control_map["Motion Transfer"]
|
||||
|
||||
def get_response(
|
||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> MoonvalleyPromptResponse:
|
||||
return poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, MoonvalleyTextToVideoRequest, "prompt_text",
|
||||
multiline=True
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="gopro, bright, contrast, static, overexposed, bright, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, contrast, saturated, vibrant, glowing, cross dissolve, texture, videogame, saturation, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, transition, dissolve, cross-dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring, static",
|
||||
),
|
||||
|
||||
"resolution": (IO.COMBO, {
|
||||
"options": ["16:9 (1920 x 1080)",
|
||||
"9:16 (1080 x 1920)",
|
||||
"1:1 (1152 x 1152)",
|
||||
"4:3 (1440 x 1080)",
|
||||
"3:4 (1080 x 1440)",
|
||||
"21:9 (2560 x 1080)"],
|
||||
"default": "16:9 (1920 x 1080)",
|
||||
"tooltip": "Resolution of the output video",
|
||||
}),
|
||||
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
|
||||
"prompt_adherence": model_field_to_node_input(IO.FLOAT,MoonvalleyTextToVideoInferenceParams,"guidance_scale",default=7.0, step=1, min=1, max=20),
|
||||
"seed": model_field_to_node_input(IO.INT,MoonvalleyTextToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
|
||||
"steps": model_field_to_node_input(IO.INT, MoonvalleyTextToVideoInferenceParams, "steps", default=100, min=1, max=100),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
"optional": {
|
||||
"image": model_field_to_node_input(
|
||||
IO.IMAGE,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
"image_url",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = "api node/video/Moonvalley Marey"
|
||||
API_NODE = True
|
||||
|
||||
def generate(self, **kwargs):
|
||||
return None
|
||||
|
||||
# --- MoonvalleyImg2VideoNode ---
|
||||
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return super().INPUT_TYPES()
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
DESCRIPTION = "Moonvalley Marey Image to Video Node"
|
||||
|
||||
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if (image is None):
|
||||
raise MoonvalleyApiError("image is required")
|
||||
total_frames = get_total_frames_from_length()
|
||||
|
||||
validate_input_image(image,True)
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||
|
||||
inference_params=MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=total_frames,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
use_negative_prompts=True
|
||||
)
|
||||
"""Upload image to comfy backend to have a URL available for further processing"""
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
return (video, )
|
||||
|
||||
# --- MoonvalleyVid2VidNode ---
|
||||
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
input_types = super().INPUT_TYPES()
|
||||
for param in ["resolution", "image"]:
|
||||
if param in input_types["required"]:
|
||||
del input_types["required"][param]
|
||||
if param in input_types["optional"]:
|
||||
del input_types["optional"][param]
|
||||
input_types["optional"] = {
|
||||
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically."}),
|
||||
"control_type": (
|
||||
["Motion Transfer", "Pose Transfer"],
|
||||
{"default": "Motion Transfer"},
|
||||
),
|
||||
"motion_intensity": (
|
||||
"INT",
|
||||
{
|
||||
"default": 100,
|
||||
"step": 1,
|
||||
"min": 0,
|
||||
"max": 100,
|
||||
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return input_types
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
|
||||
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
|
||||
video = kwargs.get("video")
|
||||
num_frames = get_total_frames_from_length()
|
||||
|
||||
if not video :
|
||||
raise MoonvalleyApiError("video is required")
|
||||
|
||||
|
||||
"""Validate video input"""
|
||||
video_url=""
|
||||
if video:
|
||||
validated_video = validate_input_video(video, num_frames, False)
|
||||
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
|
||||
|
||||
control_type = kwargs.get("control_type")
|
||||
motion_intensity = kwargs.get("motion_intensity")
|
||||
|
||||
"""Validate prompts and inference input"""
|
||||
validate_prompts(prompt, negative_prompt)
|
||||
inference_params=MoonvalleyVideoToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
control_params={'motion_intensity': motion_intensity}
|
||||
)
|
||||
|
||||
control = self.parseControlParameter(control_type)
|
||||
|
||||
request = MoonvalleyVideoToVideoRequest(
|
||||
control_type=control,
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyVideoToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
|
||||
return (video, )
|
||||
|
||||
# --- MoonvalleyTxt2VideoNode ---
|
||||
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
RETURN_NAMES = ("video",)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
input_types = super().INPUT_TYPES()
|
||||
# Remove image-specific parameters
|
||||
for param in ["image"]:
|
||||
if param in input_types["optional"]:
|
||||
del input_types["optional"][param]
|
||||
return input_types
|
||||
|
||||
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs):
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||
num_frames = get_total_frames_from_length()
|
||||
|
||||
inference_params=MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
guidance_scale=kwargs.get("prompt_adherence"),
|
||||
num_frames=num_frames,
|
||||
width=width_height.get("width"),
|
||||
height=width_height.get("height"),
|
||||
)
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_creation_response = initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = self.get_response(
|
||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||
)
|
||||
|
||||
video = download_url_to_video_output(final_response.output_url)
|
||||
return (video, )
|
||||
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
|
||||
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
|
||||
# "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
|
||||
}
|
||||
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
|
||||
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
|
||||
# "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
|
||||
}
|
||||
|
||||
def get_total_frames_from_length(length="5s"):
|
||||
# if length == '5s':
|
||||
# return 128
|
||||
# elif length == '10s':
|
||||
# return 256
|
||||
return 128
|
||||
# else:
|
||||
# raise MoonvalleyApiError("length is required")
|
||||
@@ -1,6 +1,7 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
|
||||
@@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
||||
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||
|
||||
class CacheKeySet:
|
||||
class CacheKeySet(ABC):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
@abstractmethod
|
||||
async def add_keys(self, node_ids):
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
@@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
@@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
return False
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
def get_node_signature(self, dynprompt, node_id):
|
||||
async def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
@@ -150,9 +150,10 @@ class BasicCache:
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
await self.cache_key_set.add_keys(node_ids)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.initialized = True
|
||||
|
||||
@@ -201,13 +202,13 @@ class BasicCache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def _ensure_subcache(self, node_id, children_ids):
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
self.subcaches[subcache_key] = subcache
|
||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
@@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
@@ -273,8 +274,8 @@ class LRUCache(BasicCache):
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
@@ -303,11 +304,11 @@ class LRUCache(BasicCache):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
super()._ensure_subcache(node_id, children_ids)
|
||||
await super()._ensure_subcache(node_id, children_ids)
|
||||
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
await self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.children[cache_key] = []
|
||||
@@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
@@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
@@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
@@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
from typing import Type, Literal
|
||||
|
||||
import nodes
|
||||
import asyncio
|
||||
from comfy_execution.graph_utils import is_link
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
@@ -100,6 +101,8 @@ class TopologicalSort:
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.externalBlocks = 0
|
||||
self.unblockedEvent = asyncio.Event()
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
@@ -153,6 +156,16 @@ class TopologicalSort:
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
|
||||
def add_external_block(self, node_id):
|
||||
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
|
||||
self.externalBlocks += 1
|
||||
self.blockCount[node_id] += 1
|
||||
def unblock():
|
||||
self.externalBlocks -= 1
|
||||
self.blockCount[node_id] -= 1
|
||||
self.unblockedEvent.set()
|
||||
return unblock
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return False
|
||||
|
||||
@@ -181,11 +194,16 @@ class ExecutionList(TopologicalSort):
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def stage_node_execution(self):
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None, None, None
|
||||
available = self.get_ready_nodes()
|
||||
while len(available) == 0 and self.externalBlocks > 0:
|
||||
# Wait for an external block to be released
|
||||
await self.unblockedEvent.wait()
|
||||
self.unblockedEvent.clear()
|
||||
available = self.get_ready_nodes()
|
||||
if len(available) == 0:
|
||||
cycled_nodes = self.get_nodes_in_cycle()
|
||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||
|
||||
288
comfy_execution/progress.py
Normal file
288
comfy_execution/progress.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from typing import TypedDict, Dict, Optional
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
from abc import ABC
|
||||
from tqdm import tqdm
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
Running = "running"
|
||||
Finished = "finished"
|
||||
Error = "error"
|
||||
|
||||
class NodeProgressState(TypedDict):
|
||||
"""
|
||||
A class to represent the state of a node's progress.
|
||||
"""
|
||||
state: NodeState
|
||||
value: float
|
||||
max: float
|
||||
|
||||
class ProgressHandler(ABC):
|
||||
"""
|
||||
Abstract base class for progress handlers.
|
||||
Progress handlers receive progress updates and display them in various ways.
|
||||
"""
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
pass
|
||||
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node starts processing"""
|
||||
pass
|
||||
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
"""Called when a node's progress is updated"""
|
||||
pass
|
||||
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
"""Called when a node finishes processing"""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
"""Called when the progress registry is reset"""
|
||||
pass
|
||||
|
||||
def enable(self):
|
||||
"""Enable this handler"""
|
||||
self.enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable this handler"""
|
||||
self.enabled = False
|
||||
|
||||
class CLIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that displays progress using tqdm progress bars in the CLI.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__("cli")
|
||||
self.progress_bars: Dict[str, tqdm] = {}
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Create a new tqdm progress bar
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=state["max"],
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
|
||||
@override
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Handle case where start_handler wasn't called
|
||||
if node_id not in self.progress_bars:
|
||||
self.progress_bars[node_id] = tqdm(
|
||||
total=max_value,
|
||||
desc=f"Node {node_id}",
|
||||
unit="steps",
|
||||
leave=True,
|
||||
position=len(self.progress_bars)
|
||||
)
|
||||
self.progress_bars[node_id].update(value)
|
||||
else:
|
||||
# Update existing progress bar
|
||||
if max_value != self.progress_bars[node_id].total:
|
||||
self.progress_bars[node_id].total = max_value
|
||||
# Calculate the update amount (difference from current position)
|
||||
current_position = self.progress_bars[node_id].n
|
||||
update_amount = value - current_position
|
||||
if update_amount > 0:
|
||||
self.progress_bars[node_id].update(update_amount)
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Complete and close the progress bar if it exists
|
||||
if node_id in self.progress_bars:
|
||||
# Ensure the bar shows 100% completion
|
||||
remaining = state["max"] - self.progress_bars[node_id].n
|
||||
if remaining > 0:
|
||||
self.progress_bars[node_id].update(remaining)
|
||||
self.progress_bars[node_id].close()
|
||||
del self.progress_bars[node_id]
|
||||
|
||||
@override
|
||||
def reset(self):
|
||||
# Close all progress bars
|
||||
for bar in self.progress_bars.values():
|
||||
bar.close()
|
||||
self.progress_bars.clear()
|
||||
|
||||
class WebUIProgressHandler(ProgressHandler):
|
||||
"""
|
||||
Handler that sends progress updates to the WebUI via WebSockets.
|
||||
"""
|
||||
def __init__(self, server_instance):
|
||||
super().__init__("webui")
|
||||
self.server_instance = server_instance
|
||||
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
self.registry = registry
|
||||
|
||||
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
||||
"""Send the current progress state to the client"""
|
||||
if self.server_instance is None:
|
||||
return
|
||||
|
||||
# Only send info for non-pending nodes
|
||||
active_nodes = {
|
||||
node_id: {
|
||||
"value": state["value"],
|
||||
"max": state["max"],
|
||||
"state": state["state"].value,
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
for node_id, state in nodes.items()
|
||||
if state["state"] != NodeState.Pending
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
self.server_instance.send_sync("progress_state", {
|
||||
"prompt_id": prompt_id,
|
||||
"nodes": active_nodes
|
||||
})
|
||||
|
||||
@override
|
||||
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
@override
|
||||
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
if image:
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||
}
|
||||
self.server_instance.send_sync(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), self.server_instance.client_id)
|
||||
|
||||
|
||||
@override
|
||||
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
class ProgressRegistry:
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.nodes: Dict[str, NodeProgressState] = {}
|
||||
self.handlers: Dict[str, ProgressHandler] = {}
|
||||
|
||||
def register_handler(self, handler: ProgressHandler) -> None:
|
||||
"""Register a progress handler"""
|
||||
self.handlers[handler.name] = handler
|
||||
|
||||
def unregister_handler(self, handler_name: str) -> None:
|
||||
"""Unregister a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
# Allow handler to clean up resources
|
||||
self.handlers[handler_name].reset()
|
||||
del self.handlers[handler_name]
|
||||
|
||||
def enable_handler(self, handler_name: str) -> None:
|
||||
"""Enable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].enable()
|
||||
|
||||
def disable_handler(self, handler_name: str) -> None:
|
||||
"""Disable a progress handler"""
|
||||
if handler_name in self.handlers:
|
||||
self.handlers[handler_name].disable()
|
||||
|
||||
def ensure_entry(self, node_id: str) -> NodeProgressState:
|
||||
"""Ensure a node entry exists"""
|
||||
if node_id not in self.nodes:
|
||||
self.nodes[node_id] = NodeProgressState(
|
||||
state = NodeState.Pending,
|
||||
value = 0,
|
||||
max = 1
|
||||
)
|
||||
return self.nodes[node_id]
|
||||
|
||||
def start_progress(self, node_id: str) -> None:
|
||||
"""Start progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = 0.0
|
||||
entry["max"] = 1.0
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.start_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def update_progress(self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]) -> None:
|
||||
"""Update progress for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Running
|
||||
entry["value"] = value
|
||||
entry["max"] = max_value
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)
|
||||
|
||||
def finish_progress(self, node_id: str) -> None:
|
||||
"""Finish progress tracking for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
entry["state"] = NodeState.Finished
|
||||
entry["value"] = entry["max"]
|
||||
|
||||
# Notify all enabled handlers
|
||||
for handler in self.handlers.values():
|
||||
if handler.enabled:
|
||||
handler.finish_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def reset_handlers(self) -> None:
|
||||
"""Reset all handlers"""
|
||||
for handler in self.handlers.values():
|
||||
handler.reset()
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry = ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
|
||||
global global_progress_registry
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
if global_progress_registry is not None:
|
||||
global_progress_registry.reset_handlers()
|
||||
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
handler.set_registry(global_progress_registry)
|
||||
global_progress_registry.register_handler(handler)
|
||||
|
||||
def get_progress_state() -> ProgressRegistry:
|
||||
return global_progress_registry
|
||||
46
comfy_execution/utils.py
Normal file
46
comfy_execution/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import contextvars
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
|
||||
Attributes:
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: Optional[int]
|
||||
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
return current_executing_context.get(None)
|
||||
|
||||
class CurrentNodeContext:
|
||||
"""
|
||||
Context manager for setting the current executing node context.
|
||||
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
)
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
self.token = current_executing_context.set(self.context)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.token is not None:
|
||||
current_executing_context.reset(self.token)
|
||||
@@ -133,6 +133,14 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create in-memory WAV buffer
|
||||
wav_buffer = io.BytesIO()
|
||||
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
|
||||
wav_buffer.seek(0) # Rewind for reading
|
||||
|
||||
# Use PyAV to convert and add metadata
|
||||
input_container = av.open(wav_buffer)
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format=format)
|
||||
@@ -142,6 +150,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Set up the output stream with appropriate properties
|
||||
input_container.streams.audio[0]
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
@@ -166,16 +175,18 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
||||
else: #format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Copy frames from input to output
|
||||
for frame in input_container.decode(audio=0):
|
||||
frame.pts = None # Let PyAV handle timestamps
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
|
||||
@@ -583,49 +583,6 @@ class GetImageSize:
|
||||
|
||||
return width, height, batch_size
|
||||
|
||||
class ImageRotate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "rotate"
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def rotate(self, image, rotation):
|
||||
rotate_by = 0
|
||||
if rotation.startswith("90"):
|
||||
rotate_by = 1
|
||||
elif rotation.startswith("180"):
|
||||
rotate_by = 2
|
||||
elif rotation.startswith("270"):
|
||||
rotate_by = 3
|
||||
|
||||
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
|
||||
return (image,)
|
||||
|
||||
class ImageFlip:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "flip"
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def flip(self, image, flip_method):
|
||||
if flip_method.startswith("x"):
|
||||
image = torch.flip(image, dims=[1])
|
||||
elif flip_method.startswith("y"):
|
||||
image = torch.flip(image, dims=[2])
|
||||
|
||||
return (image,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
@@ -637,6 +594,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ImageStitch": ImageStitch,
|
||||
"ResizeAndPadImage": ResizeAndPadImage,
|
||||
"GetImageSize": GetImageSize,
|
||||
"ImageRotate": ImageRotate,
|
||||
"ImageFlip": ImageFlip,
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import os
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def normalize_path(path):
|
||||
return path.replace('\\', '/')
|
||||
@@ -18,14 +16,7 @@ class Load3D():
|
||||
|
||||
os.makedirs(input_dir, exist_ok=True)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
base_path = Path(folder_paths.get_input_directory())
|
||||
|
||||
files = [
|
||||
normalize_path(str(file_path.relative_to(base_path)))
|
||||
for file_path in input_path.rglob("*")
|
||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
|
||||
]
|
||||
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
|
||||
|
||||
return {"required": {
|
||||
"model_file": (sorted(files), {"file_upload": True}),
|
||||
@@ -70,14 +61,7 @@ class Load3DAnimation():
|
||||
|
||||
os.makedirs(input_dir, exist_ok=True)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
base_path = Path(folder_paths.get_input_directory())
|
||||
|
||||
files = [
|
||||
normalize_path(str(file_path.relative_to(base_path)))
|
||||
for file_path in input_path.rglob("*")
|
||||
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
|
||||
]
|
||||
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.fbx'))]
|
||||
|
||||
return {"required": {
|
||||
"model_file": (sorted(files), {"file_upload": True}),
|
||||
|
||||
@@ -134,8 +134,8 @@ class LTXVAddGuide:
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1 and frame_idx != 0:
|
||||
frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1 # frame index - 1 must be divisible by 8 or frame_idx == 0
|
||||
if guide_length > 1:
|
||||
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8
|
||||
|
||||
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
|
||||
|
||||
@@ -144,7 +144,7 @@ class LTXVAddGuide:
|
||||
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
|
||||
keyframe_idxs, _ = get_keyframe_idxs(cond)
|
||||
_, latent_coords = self._patchifier.patchify(guiding_latent)
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
|
||||
pixel_coords[:, 0] += frame_idx
|
||||
if keyframe_idxs is None:
|
||||
keyframe_idxs = pixel_coords
|
||||
|
||||
@@ -152,7 +152,7 @@ class ImageColorToMask:
|
||||
def image_to_mask(self, image, color):
|
||||
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
||||
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
||||
mask = torch.where(temp == color, 1.0, 0).float()
|
||||
mask = torch.where(temp == color, 255, 0).float()
|
||||
return (mask,)
|
||||
|
||||
class SolidMask:
|
||||
|
||||
@@ -78,75 +78,7 @@ class SkipLayerGuidanceDiT:
|
||||
|
||||
return (m, )
|
||||
|
||||
class SkipLayerGuidanceDiTSimple:
|
||||
'''
|
||||
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
||||
'''
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "skip_guidance"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
|
||||
|
||||
CATEGORY = "advanced/guidance"
|
||||
|
||||
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
|
||||
def skip(args, extra_args):
|
||||
return args
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
double_layers = re.findall(r'\d+', double_layers)
|
||||
double_layers = [int(i) for i in double_layers]
|
||||
|
||||
single_layers = re.findall(r'\d+', single_layers)
|
||||
single_layers = [int(i) for i in single_layers]
|
||||
|
||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||
return (model, )
|
||||
|
||||
def calc_cond_batch_function(args):
|
||||
x = args["input"]
|
||||
model = args["model"]
|
||||
conds = args["conds"]
|
||||
sigma = args["sigma"]
|
||||
|
||||
model_options = args["model_options"]
|
||||
slg_model_options = model_options.copy()
|
||||
|
||||
for layer in double_layers:
|
||||
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "double_block", layer)
|
||||
|
||||
for layer in single_layers:
|
||||
slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "single_block", layer)
|
||||
|
||||
cond, uncond = conds
|
||||
sigma_ = sigma[0].item()
|
||||
if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None:
|
||||
cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options)
|
||||
_, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options)
|
||||
out = [cond_out, uncond_out]
|
||||
else:
|
||||
out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options)
|
||||
|
||||
return out
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
||||
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
|
||||
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.44"
|
||||
__version__ = "0.3.43"
|
||||
|
||||
135
execution.py
135
execution.py
@@ -8,12 +8,14 @@ import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
DependencyAwareCache,
|
||||
@@ -28,6 +30,8 @@ from comfy_execution.graph import (
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@@ -39,12 +43,13 @@ class DuplicateNodeError(Exception):
|
||||
pass
|
||||
|
||||
class IsChangedCache:
|
||||
def __init__(self, dynprompt, outputs_cache):
|
||||
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
if node_id in self.is_changed:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
@@ -62,7 +67,8 @@ class IsChangedCache:
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
logging.warning("WARNING: {}".format(e))
|
||||
@@ -164,7 +170,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def resolve_map_node_over_list_results(results):
|
||||
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
|
||||
if len(remaining) == 0:
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
else:
|
||||
done, pending = await asyncio.wait(remaining)
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -178,7 +196,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||
|
||||
results = []
|
||||
def process_inputs(inputs, index=None, input_is_list=False):
|
||||
async def process_inputs(inputs, index=None, input_is_list=False):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
execution_block = None
|
||||
@@ -194,20 +212,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
results.append(getattr(obj, func)(**inputs))
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
return await f(**args)
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
if task.done():
|
||||
result = task.result()
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(task)
|
||||
else:
|
||||
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||
result = f(**inputs)
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(execution_block)
|
||||
|
||||
if input_is_list:
|
||||
process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
await process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||
elif max_len_input == 0:
|
||||
process_inputs({})
|
||||
await process_inputs({})
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
process_inputs(input_dict, i)
|
||||
await process_inputs(input_dict, i)
|
||||
return results
|
||||
|
||||
|
||||
def merge_result_data(results, obj):
|
||||
# check which outputs need concatenating
|
||||
output = []
|
||||
@@ -229,11 +264,18 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
|
||||
return output, ui, has_subgraph, False
|
||||
|
||||
def get_output_from_returns(return_values, obj):
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
@@ -267,6 +309,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
# TODO: Think there's an existing bug here
|
||||
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
|
||||
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
|
||||
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui, has_subgraph
|
||||
@@ -279,7 +325,7 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@@ -291,11 +337,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
if unique_id in pending_subgraph_results:
|
||||
if unique_id in pending_async_nodes:
|
||||
results = []
|
||||
for r in pending_async_nodes[unique_id]:
|
||||
if isinstance(r, asyncio.Task):
|
||||
try:
|
||||
results.append(r.result())
|
||||
except Exception as ex:
|
||||
# An async task failed - propagate the exception up
|
||||
del pending_async_nodes[unique_id]
|
||||
raise ex
|
||||
else:
|
||||
results.append(r)
|
||||
del pending_async_nodes[unique_id]
|
||||
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
||||
elif unique_id in pending_subgraph_results:
|
||||
cached_results = pending_subgraph_results[unique_id]
|
||||
resolved_outputs = []
|
||||
for is_subgraph, result in cached_results:
|
||||
@@ -317,6 +378,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
@@ -328,7 +390,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
@@ -357,8 +420,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
else:
|
||||
return block
|
||||
def pre_execute_cb(call_index):
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
async def await_completion():
|
||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
unblock()
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
@@ -401,7 +474,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
|
||||
subcache.clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for link in new_output_links:
|
||||
@@ -446,6 +520,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
|
||||
return (ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
@@ -500,6 +575,11 @@ class PromptExecutor:
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
@@ -512,9 +592,11 @@ class PromptExecutor:
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
cached_nodes = []
|
||||
@@ -527,6 +609,7 @@ class PromptExecutor:
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
@@ -534,12 +617,13 @@ class PromptExecutor:
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = execution_list.stage_node_execution()
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
@@ -569,7 +653,7 @@ class PromptExecutor:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
@@ -646,7 +730,7 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
r = await validate_inputs(prompt_id, prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
@@ -771,7 +855,8 @@ def validate_inputs(prompt, item, validated):
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
@@ -804,7 +889,7 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
async def validate_prompt(prompt_id, prompt):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@@ -847,7 +932,7 @@ def validate_prompt(prompt):
|
||||
valid = False
|
||||
reasons = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
m = await validate_inputs(prompt_id, prompt, o, validated)
|
||||
valid = m[0]
|
||||
reasons = m[1]
|
||||
except Exception as ex:
|
||||
|
||||
21
main.py
21
main.py
@@ -11,6 +11,8 @@ import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
@@ -131,7 +133,7 @@ import comfy.utils
|
||||
|
||||
import execution
|
||||
import server
|
||||
from server import BinaryEventTypes
|
||||
from protocol import BinaryEventTypes
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
@@ -227,14 +229,25 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
|
||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
||||
)
|
||||
|
||||
|
||||
def hijack_progress(server_instance):
|
||||
def hook(value, total, preview_image):
|
||||
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||
executing_context = get_executing_context()
|
||||
if prompt_id is None and executing_context is not None:
|
||||
prompt_id = executing_context.prompt_id
|
||||
if node_id is None and executing_context is not None:
|
||||
node_id = executing_context.node_id
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
|
||||
if prompt_id is None:
|
||||
prompt_id = server_instance.last_prompt_id
|
||||
if node_id is None:
|
||||
node_id = server_instance.last_node_id
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
if preview_image is not None:
|
||||
# Also send old method for backward compatibility
|
||||
# TODO - Remove after this repo is updated to frontend with metadata support
|
||||
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
||||
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
1
nodes.py
1
nodes.py
@@ -2310,7 +2310,6 @@ def init_builtin_api_nodes():
|
||||
"nodes_pika.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_tripo.py",
|
||||
"nodes_moonvalley.py",
|
||||
"nodes_rodin.py",
|
||||
"nodes_gemini.py",
|
||||
]
|
||||
|
||||
7
protocol.py
Normal file
7
protocol.py
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
PREVIEW_IMAGE_WITH_METADATA = 4
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.44"
|
||||
version = "0.3.43"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.35
|
||||
comfyui-embedded-docs==0.2.4
|
||||
comfyui-workflow-templates==0.1.31
|
||||
comfyui-embedded-docs==0.2.3
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
51
server.py
51
server.py
@@ -35,11 +35,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
@@ -643,7 +639,8 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
valid = execution.validate_prompt(prompt)
|
||||
prompt_id = str(uuid.uuid4())
|
||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
@@ -651,7 +648,6 @@ class PromptServer():
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
@@ -766,6 +762,10 @@ class PromptServer():
|
||||
async def send(self, event, data, sid=None):
|
||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
await self.send_image(data, sid=sid)
|
||||
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
|
||||
# data is (preview_image, metadata)
|
||||
preview_image, metadata = data
|
||||
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
await self.send_bytes(event, data, sid)
|
||||
else:
|
||||
@@ -804,6 +804,43 @@ class PromptServer():
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
|
||||
image_type = image_data[0]
|
||||
image = image_data[1]
|
||||
max_size = image_data[2]
|
||||
if max_size is not None:
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.Resampling.LANCZOS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
|
||||
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
|
||||
|
||||
# Prepare metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["image_type"] = mimetype
|
||||
|
||||
# Serialize metadata as JSON
|
||||
import json
|
||||
metadata_json = json.dumps(metadata).encode('utf-8')
|
||||
metadata_length = len(metadata_json)
|
||||
|
||||
# Prepare image data
|
||||
bytesIO = BytesIO()
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
image_bytes = bytesIO.getvalue()
|
||||
|
||||
# Combine metadata and image
|
||||
combined_data = bytearray()
|
||||
combined_data.extend(struct.pack(">I", metadata_length))
|
||||
combined_data.extend(metadata_json)
|
||||
combined_data.extend(image_bytes)
|
||||
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
|
||||
|
||||
async def send_bytes(self, event, data, sid=None):
|
||||
message = self.encode_bytes(event, data)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pytest>=7.8.0
|
||||
pytest-aiohttp
|
||||
pytest-asyncio
|
||||
websocket-client
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Config for testing nodes
|
||||
testing:
|
||||
custom_nodes: tests/inference/testing_nodes
|
||||
custom_nodes: testing_nodes
|
||||
|
||||
|
||||
410
tests/inference/test_async_nodes.py
Normal file
410
tests/inference/test_async_nodes.py
Normal file
@@ -0,0 +1,410 @@
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
import urllib.error
|
||||
import numpy as np
|
||||
import subprocess
|
||||
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.inference.test_execution import ComfyClient
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
class TestAsyncNodes:
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
])
|
||||
def _server(self, args_pytest, request):
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
# Running server with args: pargs
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, _server):
|
||||
client = ComfyClient()
|
||||
n_tries = 5
|
||||
for i in range(n_tries):
|
||||
time.sleep(4)
|
||||
try:
|
||||
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||
except ConnectionRefusedError:
|
||||
# Retrying...
|
||||
pass
|
||||
else:
|
||||
break
|
||||
yield client
|
||||
del client
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@fixture
|
||||
def client(self, shared_client, request):
|
||||
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
|
||||
yield shared_client
|
||||
|
||||
@fixture
|
||||
def builder(self, request):
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
# Happy Path Tests
|
||||
|
||||
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that a basic async node executes correctly."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
|
||||
output = g.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution completed
|
||||
assert result.did_run(sleep_node), "Async sleep node should have executed"
|
||||
assert result.did_run(output), "Output node should have executed"
|
||||
|
||||
# Verify the image passed through correctly
|
||||
result_images = result.get_images(output)
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
|
||||
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that multiple async nodes execute in parallel."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create multiple async sleep nodes with different durations
|
||||
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
|
||||
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
|
||||
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||
|
||||
# Add outputs for each
|
||||
_output1 = g.node("PreviewImage", images=sleep1.out(0))
|
||||
_output2 = g.node("PreviewImage", images=sleep2.out(0))
|
||||
_output3 = g.node("PreviewImage", images=sleep3.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
|
||||
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
|
||||
|
||||
# Verify all nodes executed
|
||||
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
|
||||
|
||||
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with proper dependency handling."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Chain of async operations
|
||||
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||
|
||||
# Average depends on both async results
|
||||
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
|
||||
output = g.node("SaveImage", images=average.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify execution order
|
||||
assert result.did_run(sleep1) and result.did_run(sleep2)
|
||||
assert result.did_run(average) and result.did_run(output)
|
||||
|
||||
# Verify averaged result
|
||||
result_images = result.get_images(output)
|
||||
avg_value = np.array(result_images[0]).mean()
|
||||
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"
|
||||
|
||||
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async VALIDATE_INPUTS function."""
|
||||
g = builder
|
||||
# Create a test node with async validation
|
||||
validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
|
||||
g.node("SaveImage", images=validation_node.out(0))
|
||||
|
||||
# Should pass validation
|
||||
result = client.run(g)
|
||||
assert result.did_run(validation_node)
|
||||
|
||||
# Test validation failure
|
||||
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold
|
||||
with pytest.raises(urllib.error.HTTPError):
|
||||
client.run(g)
|
||||
|
||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with lazy evaluation."""
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
|
||||
|
||||
# Create async nodes that will be evaluated lazily
|
||||
sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
|
||||
sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)
|
||||
|
||||
# Use lazy mix that only needs sleep1 (mask=0.0)
|
||||
lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
|
||||
g.node("SaveImage", images=lazy_mix.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should only execute sleep1, not sleep2
|
||||
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
|
||||
assert result.did_run(sleep1), "Sleep1 should have executed"
|
||||
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
|
||||
|
||||
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async check_lazy_status function."""
|
||||
g = builder
|
||||
# Create a node with async check_lazy_status
|
||||
lazy_node = g.node("TestAsyncLazyCheck",
|
||||
input1="value1",
|
||||
input2="value2",
|
||||
condition=True)
|
||||
g.node("SaveImage", images=lazy_node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
assert result.did_run(lazy_node)
|
||||
|
||||
# Error Handling Tests
|
||||
|
||||
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async execution errors are properly handled."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
# Create an async node that will error
|
||||
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||
g.node("SaveImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"
|
||||
|
||||
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async validation error handling."""
|
||||
g = builder
|
||||
# Node with async validation that will fail
|
||||
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
|
||||
g.node("SaveImage", images=validation_node.out(0))
|
||||
|
||||
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||
client.run(g)
|
||||
# Verify it's a validation error
|
||||
assert exc_info.value.code == 400
|
||||
|
||||
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test handling of async operations that timeout."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
# Very long sleep that would timeout
|
||||
timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
|
||||
g.node("SaveImage", images=timeout_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised a timeout error"
|
||||
except Exception as e:
|
||||
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"
|
||||
|
||||
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that workflow can recover after async errors."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# First run with error
|
||||
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||
g.node("SaveImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
except Exception:
|
||||
pass # Expected
|
||||
|
||||
# Second run should succeed
|
||||
g2 = GraphBuilder(prefix="recovery_test")
|
||||
image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
|
||||
g2.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
result = client.run(g2)
|
||||
assert result.did_run(sleep_node), "Should be able to run after error"
|
||||
|
||||
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test handling when sync node errors while async node is executing."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Async node that takes time
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||
|
||||
# Sync node that will error immediately
|
||||
error_node = g.node("TestSyncError", value=image.out(0))
|
||||
|
||||
# Both feed into output
|
||||
g.node("PreviewImage", images=sleep_node.out(0))
|
||||
g.node("PreviewImage", images=error_node.out(0))
|
||||
|
||||
try:
|
||||
client.run(g)
|
||||
assert False, "Should have raised an error"
|
||||
except Exception as e:
|
||||
# Verify the sync error was caught even though async was running
|
||||
assert 'prompt_id' in e.args[0]
|
||||
|
||||
# Edge Cases
|
||||
|
||||
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with execution blockers."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Async sleep nodes
|
||||
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||
|
||||
# Create list of images
|
||||
image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))
|
||||
|
||||
# Create list of blocking conditions - [False, True] to block only the second item
|
||||
int1 = g.node("StubInt", value=1)
|
||||
int2 = g.node("StubInt", value=2)
|
||||
block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))
|
||||
|
||||
# Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
|
||||
compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")
|
||||
|
||||
# Block based on the comparison results
|
||||
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||
|
||||
output = g.node("PreviewImage", images=blocker.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have blocked second image"
|
||||
|
||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async nodes are properly cached."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||
g.node("SaveImage", images=sleep_node.out(0))
|
||||
|
||||
# First run
|
||||
result1 = client.run(g)
|
||||
assert result1.did_run(sleep_node), "Should run first time"
|
||||
|
||||
# Second run - should be cached
|
||||
start_time = time.time()
|
||||
result2 = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
assert not result2.did_run(sleep_node), "Should be cached"
|
||||
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
|
||||
|
||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes within dynamically generated prompts."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Node that generates async nodes dynamically
|
||||
dynamic_async = g.node("TestDynamicAsyncGeneration",
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
num_async_nodes=3,
|
||||
sleep_duration=0.2)
|
||||
g.node("SaveImage", images=dynamic_async.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Should execute async nodes in parallel within dynamic prompt
|
||||
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
|
||||
assert result.did_run(dynamic_async)
|
||||
|
||||
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async resources are properly cleaned up."""
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create multiple async nodes that use resources
|
||||
resource_nodes = []
|
||||
for i in range(5):
|
||||
node = g.node("TestAsyncResourceUser",
|
||||
value=image.out(0),
|
||||
resource_id=f"resource_{i}",
|
||||
duration=0.1)
|
||||
resource_nodes.append(node)
|
||||
g.node("PreviewImage", images=node.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify all nodes executed
|
||||
for node in resource_nodes:
|
||||
assert result.did_run(node)
|
||||
|
||||
# Run again to ensure resources were cleaned up
|
||||
result2 = client.run(g)
|
||||
# Should be cached but not error due to resource conflicts
|
||||
for node in resource_nodes:
|
||||
assert not result2.did_run(node), "Should be cached"
|
||||
|
||||
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test cancellation of async operations."""
|
||||
# This would require implementing cancellation in the client
|
||||
# For now, we'll test that long-running async operations can be interrupted
|
||||
pass # TODO: Implement when cancellation API is available
|
||||
|
||||
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test workflows with both sync and async nodes."""
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
||||
|
||||
# Mix of sync and async operations
|
||||
# Sync: lazy mix images
|
||||
sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
|
||||
# Async: sleep
|
||||
async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
|
||||
# Sync: custom validation
|
||||
sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
|
||||
# Async: sleep again
|
||||
async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)
|
||||
|
||||
output = g.node("SaveImage", images=async_op2.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Verify all nodes executed in correct order
|
||||
assert result.did_run(sync_op1)
|
||||
assert result.did_run(async_op1)
|
||||
assert result.did_run(sync_op2)
|
||||
assert result.did_run(async_op2)
|
||||
|
||||
# Image should be a mix of black and white (gray)
|
||||
result_images = result.get_images(output)
|
||||
avg_value = np.array(result_images[0]).mean()
|
||||
assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"
|
||||
@@ -252,7 +252,7 @@ class TestExecution:
|
||||
|
||||
@pytest.mark.parametrize("test_type, test_value", [
|
||||
("StubInt", 5),
|
||||
("StubFloat", 5.0)
|
||||
("StubMask", 5.0)
|
||||
])
|
||||
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@@ -497,6 +497,69 @@ class TestExecution:
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create sleep nodes for each duration
|
||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||
|
||||
# Add outputs to verify the execution
|
||||
_output1 = g.node("PreviewImage", images=sleep_node1.out(0))
|
||||
_output2 = g.node("PreviewImage", images=sleep_node2.out(0))
|
||||
_output3 = g.node("PreviewImage", images=sleep_node3.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# The test should take around 0.4 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
||||
# We'll allow for up to 0.8s total to account for overhead
|
||||
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
||||
|
||||
# Verify that all nodes executed
|
||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
|
||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||
|
||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
# Create input images with different values
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create a TestParallelSleep node that expands into multiple TestSleep nodes
|
||||
parallel_sleep = g.node("TestParallelSleep",
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
image3=image3.out(0),
|
||||
sleep1=0.4,
|
||||
sleep2=0.5,
|
||||
sleep3=0.6)
|
||||
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||
# which should complete in less than the sum of all sleeps
|
||||
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
||||
|
||||
# Verify the parallel sleep node executed
|
||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||
|
||||
# Verify we get an image as output (blend of the three input images)
|
||||
result_images = result.get_images(output)
|
||||
assert len(result_images) == 1, "Should have 1 image"
|
||||
# Average pixel value should be around 170 (255 * 2 // 3)
|
||||
avg_value = numpy.array(result_images[0]).mean()
|
||||
assert avg_value == 170, f"Image average value {avg_value} should be 170"
|
||||
|
||||
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||
# only that one entry in the list is blocked.
|
||||
|
||||
@@ -3,6 +3,7 @@ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DI
|
||||
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
@@ -13,6 +14,7 @@ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
@@ -20,4 +22,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
|
||||
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import torch
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
|
||||
class TestAsyncValidation(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"threshold": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||
# Simulate async validation (e.g., checking remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
if value > threshold:
|
||||
return f"Value {value} exceeds threshold {threshold}"
|
||||
return True
|
||||
|
||||
def process(self, value, threshold):
|
||||
# Create image based on value
|
||||
intensity = value / 10.0
|
||||
image = torch.ones([1, 512, 512, 3]) * intensity
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncError(ComfyNodeABC):
|
||||
"""Test node that errors during async execution."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "error_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def error_execution(self, value, error_after):
|
||||
await asyncio.sleep(error_after)
|
||||
raise RuntimeError("Intentional async execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncValidationError(ComfyNodeABC):
|
||||
"""Test node with async validation that always fails."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"max_value": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||
await asyncio.sleep(0.05)
|
||||
# Always fail validation for values > max_value
|
||||
if value > max_value:
|
||||
return f"Async validation failed: {value} > {max_value}"
|
||||
return True
|
||||
|
||||
def process(self, value, max_value):
|
||||
# This won't be reached if validation fails
|
||||
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncTimeout(ComfyNodeABC):
|
||||
"""Test node that simulates timeout scenarios."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
|
||||
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "timeout_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def timeout_execution(self, value, timeout, operation_time):
|
||||
try:
|
||||
# This will timeout if operation_time > timeout
|
||||
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
|
||||
return (value,)
|
||||
except asyncio.TimeoutError:
|
||||
raise RuntimeError(f"Operation timed out after {timeout} seconds")
|
||||
|
||||
|
||||
class TestSyncError(ComfyNodeABC):
|
||||
"""Test node that errors synchronously (for mixed sync/async testing)."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sync_error"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def sync_error(self, value):
|
||||
raise RuntimeError("Intentional sync execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncLazyCheck(ComfyNodeABC):
|
||||
"""Test node with async check_lazy_status."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": (IO.ANY, {"lazy": True}),
|
||||
"input2": (IO.ANY, {"lazy": True}),
|
||||
"condition": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def check_lazy_status(self, condition, input1, input2):
|
||||
# Simulate async checking (e.g., querying remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
needed = []
|
||||
if condition and input1 is None:
|
||||
needed.append("input1")
|
||||
if not condition and input2 is None:
|
||||
needed.append("input2")
|
||||
return needed
|
||||
|
||||
def process(self, input1, input2, condition):
|
||||
# Return a simple image
|
||||
return (torch.ones([1, 512, 512, 3]),)
|
||||
|
||||
|
||||
class TestDynamicAsyncGeneration(ComfyNodeABC):
|
||||
"""Test node that dynamically generates async nodes."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"image2": ("IMAGE",),
|
||||
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
|
||||
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "generate_async_workflow"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create multiple async sleep nodes
|
||||
sleep_nodes = []
|
||||
for i in range(num_async_nodes):
|
||||
image = image1 if i % 2 == 0 else image2
|
||||
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
|
||||
sleep_nodes.append(sleep_node)
|
||||
|
||||
# Average all results
|
||||
if len(sleep_nodes) == 1:
|
||||
final_node = sleep_nodes[0]
|
||||
else:
|
||||
avg_inputs = {"input1": sleep_nodes[0].out(0)}
|
||||
for i, node in enumerate(sleep_nodes[1:], 2):
|
||||
avg_inputs[f"input{i}"] = node.out(0)
|
||||
final_node = g.node("TestVariadicAverage", **avg_inputs)
|
||||
|
||||
return {
|
||||
"result": (final_node.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
|
||||
class TestAsyncResourceUser(ComfyNodeABC):
|
||||
"""Test node that uses resources during async execution."""
|
||||
|
||||
# Class-level resource tracking for testing
|
||||
_active_resources: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"resource_id": ("STRING", {"default": "resource_0"}),
|
||||
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "use_resource"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def use_resource(self, value, resource_id, duration):
|
||||
# Check if resource is already in use
|
||||
if self._active_resources.get(resource_id, False):
|
||||
raise RuntimeError(f"Resource {resource_id} is already in use!")
|
||||
|
||||
# Mark resource as in use
|
||||
self._active_resources[resource_id] = True
|
||||
|
||||
try:
|
||||
# Simulate resource usage
|
||||
await asyncio.sleep(duration)
|
||||
return (value,)
|
||||
finally:
|
||||
# Always clean up resource
|
||||
self._active_resources[resource_id] = False
|
||||
|
||||
|
||||
class TestAsyncBatchProcessing(ComfyNodeABC):
|
||||
"""Test async processing of batched inputs."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process_batch"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||
batch_size = images.shape[0]
|
||||
pbar = ProgressBar(batch_size, node_id=unique_id)
|
||||
|
||||
# Process each image in the batch
|
||||
processed = []
|
||||
for i in range(batch_size):
|
||||
# Simulate async processing
|
||||
await asyncio.sleep(process_time_per_item)
|
||||
|
||||
# Simple processing: invert the image
|
||||
processed_image = 1.0 - images[i:i+1]
|
||||
processed.append(processed_image)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Stack processed images
|
||||
result = torch.cat(processed, dim=0)
|
||||
return (result,)
|
||||
|
||||
|
||||
class TestAsyncConcurrentLimit(ComfyNodeABC):
|
||||
"""Test concurrent execution limits for async nodes."""
|
||||
|
||||
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
|
||||
"node_id": ("INT", {"default": 0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "limited_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def limited_execution(self, value, duration, node_id):
|
||||
async with self._semaphore:
|
||||
# Node {node_id} acquired semaphore
|
||||
await asyncio.sleep(duration)
|
||||
# Node {node_id} releasing semaphore
|
||||
return (value,)
|
||||
|
||||
|
||||
# Add node mappings
|
||||
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestAsyncValidation": TestAsyncValidation,
|
||||
"TestAsyncError": TestAsyncError,
|
||||
"TestAsyncValidationError": TestAsyncValidationError,
|
||||
"TestAsyncTimeout": TestAsyncTimeout,
|
||||
"TestSyncError": TestSyncError,
|
||||
"TestAsyncLazyCheck": TestAsyncLazyCheck,
|
||||
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
|
||||
"TestAsyncResourceUser": TestAsyncResourceUser,
|
||||
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
|
||||
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
|
||||
}
|
||||
|
||||
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestAsyncValidation": "Test Async Validation",
|
||||
"TestAsyncError": "Test Async Error",
|
||||
"TestAsyncValidationError": "Test Async Validation Error",
|
||||
"TestAsyncTimeout": "Test Async Timeout",
|
||||
"TestSyncError": "Test Sync Error",
|
||||
"TestAsyncLazyCheck": "Test Async Lazy Check",
|
||||
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
|
||||
"TestAsyncResourceUser": "Test Async Resource User",
|
||||
"TestAsyncBatchProcessing": "Test Async Batch Processing",
|
||||
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
import torch
|
||||
import time
|
||||
import asyncio
|
||||
from comfy.utils import ProgressBar
|
||||
from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
@@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSamplingInExpansion:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"clip": ("CLIP",),
|
||||
"vae": ("VAE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
|
||||
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
|
||||
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
|
||||
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "sampling_in_expansion"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create a basic image generation workflow using the input model, clip and vae
|
||||
# 1. Setup text prompts using the provided CLIP model
|
||||
positive_prompt = g.node("CLIPTextEncode",
|
||||
text=prompt,
|
||||
clip=clip)
|
||||
negative_prompt = g.node("CLIPTextEncode",
|
||||
text=negative_prompt,
|
||||
clip=clip)
|
||||
|
||||
# 2. Create empty latent with specified size
|
||||
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
|
||||
|
||||
# 3. Setup sampler and generate image latent
|
||||
sampler = g.node("KSampler",
|
||||
model=model,
|
||||
positive=positive_prompt.out(0),
|
||||
negative=negative_prompt.out(0),
|
||||
latent_image=empty_latent.out(0),
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
sampler_name="euler_ancestral",
|
||||
scheduler="normal")
|
||||
|
||||
# 4. Decode latent to image using VAE
|
||||
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
|
||||
|
||||
return {
|
||||
"result": (output.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sleep"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
async def sleep(self, value, seconds, unique_id):
|
||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||
start = time.time()
|
||||
expiration = start + seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
pbar.update_absolute(now - start)
|
||||
await asyncio.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
class TestParallelSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE", ),
|
||||
"image2": ("IMAGE", ),
|
||||
"image3": ("IMAGE", ),
|
||||
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "parallel_sleep"
|
||||
CATEGORY = "_for_testing"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||
# Create a graph dynamically with three TestSleep nodes
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create sleep nodes for each duration and image
|
||||
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
|
||||
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
|
||||
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
|
||||
|
||||
# Blend the results using TestVariadicAverage
|
||||
blend = g.node("TestVariadicAverage",
|
||||
input1=sleep_node1.out(0),
|
||||
input2=sleep_node2.out(0),
|
||||
input3=sleep_node3.out(0))
|
||||
|
||||
return {
|
||||
"result": (blend.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestCustomValidation5": TestCustomValidation5,
|
||||
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
||||
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestCustomValidation5": "Custom Validation 5",
|
||||
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
||||
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user