Compare commits

...

26 Commits

Author SHA1 Message Date
comfyanonymous
32a95bba8a ComfyUI version 0.3.49 2025-08-05 07:33:02 -04:00
ComfyUI Wiki
da1ad9b516 Update template to 0.1.51 (#9187) 2025-08-05 07:24:12 -04:00
comfyanonymous
d044a24398 Fix default shift and any latent size for qwen image model. (#9186) 2025-08-05 06:12:27 -04:00
ComfyUI Wiki
5be6fd09ff Update template to 0.1.48 (#9182) 2025-08-05 03:48:56 -04:00
Christian Byrne
f69609bbd6 Add Veo3 video generation node with audio support (#9110)
- Create new Veo3VideoGenerationNode that extends VeoVideoGenerationNode
- Add support for generateAudio parameter (only for Veo3 models)
- Support new Veo3 models: veo-3.0-generate-001, veo-3.0-fast-generate-001
- Fix Veo3 duration constraint to 8 seconds only
- Update original node to be clearly Veo 2 only
- Update API paths to use model parameter: /proxy/veo/{model}/generate
- Regenerate API types from staging to include generateAudio parameter
- Fix TripoModelVersion enum reference after regeneration
- Mark generated API types file in .gitattributes
2025-08-05 01:52:25 -04:00
comfyanonymous
c012400240 Initial support for qwen image model. (#9179) 2025-08-04 22:53:25 -04:00
comfyanonymous
03895dea7c Fix another issue with the PR. (#9170) 2025-08-04 04:33:04 -04:00
comfyanonymous
84f9759424 Add some warnings and prevent crash when cond devices don't match. (#9169) 2025-08-04 04:20:12 -04:00
comfyanonymous
7991341e89 Various fixes for broken things from earlier PR. (#9168) 2025-08-04 04:02:40 -04:00
comfyanonymous
140ffc7fdc Fix broken controlnet from last PR. (#9167) 2025-08-04 03:28:12 -04:00
comfyanonymous
182f90b5ec Lower cond vram use by casting at the same time as device transfer. (#9159) 2025-08-04 03:11:53 -04:00
comfyanonymous
aebac22193 Cleanup. (#9160) 2025-08-03 07:08:11 -04:00
comfyanonymous
13aaa66ec2 Make sure context is on the right device. (#9154) 2025-08-02 15:09:23 -04:00
comfyanonymous
5f582a9757 Make sure all the conds are on the right device. (#9151) 2025-08-02 15:00:13 -04:00
ComfyUI Wiki
fbcc23945d Update template to 0.1.47 (#9153) 2025-08-02 14:15:29 -04:00
Johnpaul Chiwetelu
3dfefc88d0 API for Recently Used Items (#8792)
* feat: add file creation time to model file metadata and user file info

* fix linting
2025-08-01 22:02:06 -04:00
comfyanonymous
bff60b5cfc ComfyUI version 0.3.48 2025-08-01 20:03:22 -04:00
comfyanonymous
1e638a140b Tiny wan vae optimizations. (#9136) 2025-08-01 05:25:38 -04:00
ComfyUI Wiki
4696d74305 update template to 0.1.45 (#9135) 2025-08-01 03:06:18 -04:00
comfyanonymous
5ee381c058 Fix WanFirstLastFrameToVideo node when no clip vision. (#9134) 2025-07-31 23:33:27 -04:00
Jedrzej Kosinski
4887743a2a V3 Node Schema Definition - initial (#8656) 2025-07-31 18:02:12 -04:00
comfyanonymous
97b8a2c26a More accurate explanation of release process. (#9126) 2025-07-31 05:46:23 -04:00
guill
97eb256a35 Add support for partial execution in backend (#9123)
When a prompt is submitted, it can optionally include
`partial_execution_targets` as a list of ids. If it does, rather than
adding all outputs to the execution list, we add only those in the list.
2025-07-30 22:55:28 -04:00
chaObserv
61b08d4ba6 Replace manual x * sigmoid(x) with torch silu in VAE nonlinearity (#9057) 2025-07-30 19:25:56 -04:00
comfyanonymous
da9dab7edd Small wan camera memory optimization. (#9111) 2025-07-30 05:55:26 -04:00
ComfyUI Wiki
d2aaef029c Update template to 0.1.44 (#9104) 2025-07-29 22:50:49 -04:00
40 changed files with 6002 additions and 211 deletions
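
The partial-execution commit above (#9123) says a submitted prompt "can optionally include `partial_execution_targets` as a list of ids". A sketch of what such a request might look like, assuming ComfyUI's standard /prompt endpoint; the placement of the field alongside the prompt follows the commit message, and everything else here is illustrative:

```python
# Sketch: submit a prompt but only execute two of its output nodes.
# `workflow` is a placeholder for the usual node-graph dict.
import json
import urllib.request

workflow = {}  # placeholder: your node graph, keyed by node id

payload = {
    "prompt": workflow,
    "partial_execution_targets": ["9", "12"],  # only these output ids are added to the execution list
}
req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())
```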

.gitattributes vendored (+1)
View File

@@ -1,2 +1,3 @@
/web/assets/** linguist-generated
/web/** linguist-vendored
comfy_api_nodes/apis/__init__.py linguist-generated

View File

@@ -111,7 +111,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Release Process
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
ComfyUI follows a weekly release cycle targeting Friday, but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0)

View File

@@ -130,10 +130,21 @@ class ModelFileManager:
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
@@ -144,7 +155,7 @@ class ModelFileManager:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)

View File

@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
path: str
size: int
modified: int
created: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
"modified": os.path.getmtime(path),
"created": os.path.getctime(path)
}
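
One caveat worth noting about the new "created" field in both hunks above: os.path.getctime only means creation time on Windows; on Linux and macOS it returns the inode/metadata change time instead. A quick illustration:

```python
import os

# On Windows this is the file's creation time; on most Unix filesystems
# it is the last metadata change time, not when the file was created.
print(os.path.getctime(__file__))
print(os.path.getmtime(__file__))  # modification time is portable
```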

View File

@@ -1,6 +1,7 @@
import torch
import math
import comfy.utils
import logging
class CONDRegular:
@@ -10,12 +11,15 @@ class CONDRegular:
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def process_cond(self, batch_size, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
return False
return True
def concat(self, others):
@@ -29,14 +33,14 @@ class CONDRegular:
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
def process_cond(self, batch_size, area, **kwargs):
data = self.cond
if area is not None:
dims = len(area) // 2
for i in range(dims):
data = data.narrow(i + 2, area[i + dims], area[i])
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
class CONDCrossAttn(CONDRegular):
@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True
def concat(self, others):
@@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
@@ -92,10 +99,10 @@ class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
def process_cond(self, batch_size, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
return self._copy_with(out)
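
The thread running through this hunk (and the controlnet and model_base hunks below) is #9159: conds are no longer moved to the device inside process_cond; the cast and the device transfer now happen together later via cast_to_device. A minimal sketch of why fusing the two lowers VRAM use — illustrative tensors, not ComfyUI's helper:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
cond = torch.randn(1, 77, 4096)  # hypothetical fp32 conditioning on CPU

# Two steps: a full fp32 copy lands on the device before being cast down,
# transiently holding roughly 2x the memory of the final fp16 tensor.
wasteful = cond.to(device).to(torch.float16)

# One step: the cast is applied as part of the transfer, so only the
# fp16 tensor is ever allocated on the device.
fused = cond.to(device=device, dtype=torch.float16)
```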

View File

@@ -28,6 +28,7 @@ import comfy.model_detection
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
import comfy.model_base
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
@@ -43,7 +44,6 @@ if TYPE_CHECKING:
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@@ -265,12 +265,12 @@ class ControlNet(ControlBase):
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)
def copy(self):

View File

@@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
def nonlinearity(x):
return x * torch.sigmoid(x)
# x * sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):

View File

@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
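
For reference, both replacements above are drop-in equivalent: silu(x) is defined as x * sigmoid(x), and the fused F.silu kernel simply computes it faster. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
assert torch.allclose(x * torch.sigmoid(x), F.silu(x), atol=1e-6)
```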

View File

@@ -0,0 +1,400 @@
# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
self.approximate = approximate
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
inner_dim=None,
bias: bool = True,
dtype=None, device=None, operations=None
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.net = nn.ModuleList([])
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim,
dtype=dtype,
device=device,
operations=operations
)
def forward(self, timestep, hidden_states):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
return timesteps_emb
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
dim_head: int = 64,
heads: int = 8,
dropout: float = 0.0,
bias: bool = False,
eps: float = 1e-5,
out_bias: bool = True,
out_dim: int = None,
out_context_dim: int = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim
self.heads = heads
self.dim_head = dim_head
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.dropout = dropout
# Q/K normalization
self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
# Image stream projections
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Text stream projections
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Output projections
self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
nn.Dropout(dropout)
])
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
def forward(
self,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
img_attn_output = self.to_out[0](img_attn_output)
img_attn_output = self.to_out[1](img_attn_output)
txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.img_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.txt_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.attn = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
eps=eps,
dtype=dtype,
device=device,
operations=operations,
)
def _modulate(self, x, mod_params):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
return encoder_hidden_states, hidden_states
class LastLayer(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine=False,
eps=1e-6,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class QwenImageTransformer2DModel(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 64,
out_channels: Optional[int] = 16,
num_layers: int = 60,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
image_model=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.patch_size = patch_size
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
dtype=dtype,
device=device,
operations=operations
)
self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList([
QwenImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_layers)
])
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
def pos_embeds(self, x, context):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_start = round(max(h_len, w_len))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
**kwargs
):
timestep = timesteps
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
image_rotary_emb = self.pos_embeds(x, context)
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
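
To make the view/permute/reshape dance in forward() above concrete, here is a toy round-trip of the 2x2 patch packing — a sketch that assumes a single latent frame (t == 1), which the 6-way view implies:

```python
import torch

b, c, h, w = 1, 16, 8, 8
x = torch.randn(b, c, 1, h, w)  # (batch, channels, t=1, height, width)
orig_shape = x.shape

# pack: (b, c, 1, h, w) -> (b, h/2 * w/2, c * 4), matching img_in's input
packed = x.view(b, c, h // 2, 2, w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(b, (h // 2) * (w // 2), c * 4)

# unpack: the inverse applied after proj_out
unpacked = packed.view(b, h // 2, w // 2, c, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
unpacked = unpacked.reshape(orig_shape)

assert torch.equal(x, unpacked)  # exact round trip, no arithmetic involved
```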

View File

@@ -769,8 +769,7 @@ class CameraWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
x = x + self.control_adapter(camera_conditions).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)

View File

@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
if cache_list is not None:
cache_x = cache_list[cache_idx]
cache_list[cache_idx] = None
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
del cache_x
x = F.pad(x, padding)
return super().forward(x)
@@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
@@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
return x + self.shortcut(old_x)
class AttentionBlock(nn.Module):
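
The cache_list/cache_idx indirection above is a small memory trick: the conv clears the list slot before concatenating, so the only remaining reference to the old cache tensor is the local one, and the explicit del lets it be freed as soon as torch.cat has consumed it. A toy illustration of the pattern, not the VAE code itself:

```python
import torch

def consume(cache_list, idx, x):
    cache_x = cache_list[idx]
    cache_list[idx] = None        # drop the list's reference immediately
    out = torch.cat([cache_x, x], dim=2)
    del cache_x                   # last reference gone; memory reclaimable
    return out

cache = [torch.randn(1, 4, 2, 8, 8)]
y = consume(cache, 0, torch.randn(1, 4, 3, 8, 8))
print(y.shape)  # torch.Size([1, 4, 5, 8, 8])
```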

View File

@@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
],
dim=2,
)
x = layer(x, feat_cache[idx])
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:

View File

@@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.model_management
import comfy.patcher_extension
@@ -106,10 +107,12 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
def convert_tensor(extra, dtype, device):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = comfy.model_management.cast_to_device(extra, device, dtype)
else:
extra = comfy.model_management.cast_to_device(extra, device, None)
return extra
@@ -160,7 +163,7 @@ class BaseModel(torch.nn.Module):
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
dtype = self.get_dtype()
@@ -169,20 +172,21 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype)
context = comfy.model_management.cast_to_device(context, device, dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype)
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra
@@ -398,7 +402,7 @@ class SD21UNCLIP(BaseModel):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels))
return torch.zeros((1, self.adm_channels), device=device)
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
@@ -612,9 +616,11 @@ class IP2P:
if image is None:
image = torch.zeros_like(noise)
else:
image = image.to(device=device)
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
@@ -693,7 +699,7 @@ class StableCascade_B(BaseModel):
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
out["effnet"] = comfy.conds.CONDRegular(prior)
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out
@@ -1158,10 +1164,10 @@ class WAN21_Vace(WAN21):
vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].clone()
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j]], dim=1)
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
vace_frames_out.append(vf)
vace_frames = torch.stack(vace_frames_out, dim=1)
@@ -1303,3 +1309,14 @@ class Omnigen2(BaseModel):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

View File

@@ -481,6 +481,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -867,7 +872,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)

View File

@@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None)

View File

@@ -47,6 +47,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.model_patcher
import comfy.lora
@@ -771,6 +772,7 @@ class CLIPType(Enum):
CHROMA = 15
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -791,6 +793,7 @@ class TEModel(Enum):
T5_XXL_OLD = 8
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -812,7 +815,11 @@ def detect_te_model(sd):
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
return TEModel.QWEN25_3B
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -917,6 +924,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
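
The k_proj bias width works as a fingerprint because, with grouped-query attention, k_proj projects to num_key_value_heads * head_dim. The 7B numbers come from Qwen25_7BVLI_Config in the llama.py hunk below; the 3B figure assumes that model's 2 KV heads:

```python
# Qwen2.5-VL 7B (per Qwen25_7BVLI_Config below): 4 KV heads * head_dim 128
assert 4 * 128 == 512
# Qwen2.5 3B (assumed config: 2 KV heads * head_dim 128)
assert 2 * 128 == 256
```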

View File

@@ -19,6 +19,7 @@ import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
from . import supported_models_base
from . import latent_formats
@@ -1229,7 +1230,36 @@ class Omnigen2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
sampling_settings = {
"multiplier": 1.0,
"shift": 1.15,
}
memory_usage_factor = 1.8 #TODO
unet_extra_config = {}
latent_format = latent_formats.Wan21
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.QwenImage(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@@ -43,6 +43,23 @@ class Qwen25_3BConfig:
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
hidden_size: int = 3584
intermediate_size: int = 18944
num_hidden_layers: int = 28
num_attention_heads: int = 28
num_key_value_heads: int = 4
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
vocab_size: int = 256000
@@ -348,6 +365,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@@ -0,0 +1,71 @@
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.llama
import os
import torch
import numbers
class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class QwenImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen25_7b"][0]
count_im_start = 0
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 151644 and count_im_start < 2:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 3):
if tok_pairs[template_end + 1][0] == 872:
if tok_pairs[template_end + 2][0] == 198:
template_end += 3
out = out[:, template_end:]
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask") # attention mask is useless if no masked elements
return out, pooled, extra
def te(dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_
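
For illustration: the default llama_template wraps the user prompt in a Qwen chat transcript before tokenization, and encode_token_weights then trims everything up to the user turn by scanning for the token ids that correspond to the template's "<|im_start|>user\n" sequence (the 151644/872/198 checks above). Plain string formatting shows what the encoder actually sees:

```python
# Sketch: the prompt string produced by the default template for a user
# caption, before it is handed to the Qwen2 tokenizer.
template = (
    "<|im_start|>system\nDescribe the image by detailing the color, shape, "
    "size, texture, quantity, text, spatial relationships of the objects "
    "and background:<|im_end|>\n"
    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)
print(template.format("a cat sitting on a windowsill"))
```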

View File

@@ -5,3 +5,146 @@ from .api_registry import (
register_versions as register_versions,
get_all_versions as get_all_versions,
)
import asyncio
from dataclasses import asdict
from typing import Callable, Optional
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
"""Return the *callable* override of `name` visible on `cls`, or None if every
implementation up to (and including) `base` is the placeholder defined on `base`.
If base is not provided, it will assume cls has a GET_BASE_CLASS.
"""
if base is None:
if not hasattr(cls, "GET_BASE_CLASS"):
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
base = cls.GET_BASE_CLASS()
base_attr = getattr(base, name, None)
if base_attr is None:
return None
base_func = base_attr.__func__
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
if c is base: # reached the placeholder, we're done
break
if name in c.__dict__: # first class that *defines* the attr
func = getattr(c, name).__func__
if func is not base_func: # real override
return getattr(cls, name) # bound to *cls*
return None
class _ComfyNodeInternal:
"""Class that all V3-based APIs inherit from for ComfyNode.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
@classmethod
def GET_NODE_INFO_V1(cls):
...
class _NodeOutputInternal:
"""Class that all V3-based APIs inherit from for NodeOutput.
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
...
def as_pruned_dict(dataclass_obj):
'''Return dict of dataclass object with pruned None values.'''
return prune_dict(asdict(dataclass_obj))
def prune_dict(d: dict):
return {k: v for k,v in d.items() if v is not None}
def is_class(obj):
'''
Returns True if obj is a class type.
Returns False if obj is a class instance.
'''
return isinstance(obj, type)
def copy_class(cls: type) -> type:
'''
Copy a class and its attributes.
'''
if cls is None:
return None
cls_dict = {
k: v for k, v in cls.__dict__.items()
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
}
# new class
new_cls = type(
cls.__name__,
(cls,),
cls_dict
)
# metadata preservation
new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__
return new_cls
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
# NOTE: this was ai generated and validated by hand
def shallow_clone_class(cls, new_name=None):
'''
Shallow clone a class while preserving super() functionality.
'''
new_name = new_name or f"{cls.__name__}Clone"
# Include the original class in the bases to maintain proper inheritance
new_bases = (cls,) + cls.__bases__
return type(new_name, new_bases, dict(cls.__dict__))
# NOTE: this was ai generated and validated by hand
def lock_class(cls):
'''
Lock a class so that its top-level attributes cannot be modified.
'''
# Locked instance __setattr__
def locked_instance_setattr(self, name, value):
raise AttributeError(
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
)
# Locked metaclass
class LockedMeta(type(cls)):
def __setattr__(cls_, name, value):
raise AttributeError(
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
)
# Rebuild class with locked behavior
locked_dict = dict(cls.__dict__)
locked_dict['__setattr__'] = locked_instance_setattr
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
def make_locked_method_func(type_obj, func, class_clone):
"""
Returns a function that, when called with **inputs, will execute:
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
Supports both synchronous and asynchronous methods.
"""
locked_class = lock_class(class_clone)
method = getattr(type_obj, func).__func__
# Check if the original method is async
if asyncio.iscoroutinefunction(method):
async def wrapped_async_func(**inputs):
return await method(locked_class, **inputs)
return wrapped_async_func
else:
def wrapped_func(**inputs):
return method(locked_class, **inputs)
return wrapped_func
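
A minimal usage sketch of lock_class, with a hypothetical node class; both class-level and instance-level writes are rejected:

```python
class MyNode:  # hypothetical class to lock
    category = "example"

LockedNode = lock_class(MyNode)

try:
    LockedNode.category = "other"       # blocked by the locked metaclass
except AttributeError as e:
    print(e)

try:
    LockedNode().category = "other"     # blocked by the locked __setattr__
except AttributeError as e:
    print(e)
```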

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
@@ -7,6 +8,9 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401
from comfy_api.latest._ui import _UI as ui #noqa: F401
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
@@ -72,6 +76,19 @@ class ComfyAPI_latest(ComfyAPIBase):
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod
async def get_node_list(self) -> list[type[io.ComfyNode]]:
"""
Returns a list of nodes that this extension provides.
"""
class Input:
Image = ImageInput
Audio = AudioInput
@@ -103,4 +120,5 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"ComfyExtension",
]
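
A minimal extension sketch against this API, assuming MyNode is an io.ComfyNode subclass defined elsewhere; on_load is optional, only get_node_list is abstract:

```python
class MyExtension(ComfyExtension):
    async def on_load(self) -> None:
        pass  # optional: initialize any global resources here

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [MyNode]  # hypothetical node class
```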

comfy_api/latest/_io.py Normal file (+1618)

File diff suppressed because it is too large

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal
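
A usage sketch of the cache (folder and file names are hypothetical): the first get loads from disk, the second is served from local_resources, and passing a default suppresses the missing-file exception:

```python
resources = ResourcesLocal()
key = TorchDictFolderFilename("loras", "example_lora.safetensors")

sd = resources.get(key)          # loads via load_torch_file; raises if absent
sd_again = resources.get(key)    # returned from the in-memory cache

missing = resources.get(
    TorchDictFolderFilename("loras", "does_not_exist.safetensors"),
    default=None,                # swallow the miss instead of raising
)
```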

comfy_api/latest/_ui.py Normal file (+457)
View File

@@ -0,0 +1,457 @@
from __future__ import annotations
import json
import os
import random
from io import BytesIO
from typing import Type
import av
import numpy as np
import torch
import torchaudio
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
def __init__(self, filename: str, subfolder: str, type: FolderType):
super().__init__(filename=filename, subfolder=subfolder, type=type.value)
@property
def filename(self) -> str:
return self["filename"]
@property
def subfolder(self) -> str:
return self["subfolder"]
@property
def type(self) -> FolderType:
return FolderType(self["type"])
class SavedImages(_UIOutput):
"""A UI output class to represent one or more saved images, potentially animated."""
def __init__(self, results: list[SavedResult], is_animated: bool = False):
super().__init__()
self.results = results
self.is_animated = is_animated
def as_dict(self) -> dict:
data = {"images": self.results}
if self.is_animated:
data["animated"] = (True,)
return data
class SavedAudios(_UIOutput):
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
def __init__(self, results: list[SavedResult]):
super().__init__()
self.results = results
def as_dict(self) -> dict:
return {"audio": self.results}
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
if folder_type == FolderType.input:
return folder_paths.get_input_directory()
if folder_type == FolderType.output:
return folder_paths.get_output_directory()
return folder_paths.get_temp_directory()
class ImageSaveHelper:
"""A helper class with static methods to handle image saving and metadata."""
@staticmethod
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
"""Converts a single torch tensor to a PIL Image."""
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
return metadata
@staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
metadata = PngInfo()
if cls.hidden.prompt:
metadata.add(
b"comf",
"prompt".encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
after_idat=True,
)
if cls.hidden.extra_pnginfo:
for x in cls.hidden.extra_pnginfo:
metadata.add(
b"comf",
x.encode("latin-1", "strict")
+ b"\0"
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
after_idat=True,
)
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
return exif_data
if cls.hidden.prompt is not None:
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
if cls.hidden.extra_pnginfo is not None:
inital_exif_tag = 0x010F # EXIF 0x010f = Make
for key, value in cls.hidden.extra_pnginfo.items():
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
inital_exif_tag -= 1
return exif_data
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
results = []
metadata = ImageSaveHelper._create_png_metadata(cls)
for batch_number, image_tensor in enumerate(images):
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
compress_level=compress_level,
)
)
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
file = f"{filename}_{counter:05}_.png"
save_path = os.path.join(full_output_folder, file)
pil_images[0].save(
save_path,
pnginfo=metadata,
compress_level=compress_level,
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
compress_level=compress_level,
)
return SavedImages([result], is_animated=len(images) > 1)
@staticmethod
def save_animated_webp(
images,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedResult:
"""Saves a batch of images as a single animated WebP."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
file = f"{filename}_{counter:05}_.webp"
pil_images[0].save(
os.path.join(full_output_folder, file),
save_all=True,
duration=int(1000.0 / fps),
append_images=pil_images[1:],
exif=pil_exif,
lossless=lossless,
quality=quality,
method=method,
)
return SavedResult(file, subfolder, folder_type)
@staticmethod
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
method: int,
) -> SavedImages:
"""Saves an animated WebP and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_webp(
images,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
fps=fps,
lossless=lossless,
quality=quality,
method=method,
)
return SavedImages([result], is_animated=len(images) > 1)
class AudioSaveHelper:
"""A helper class with static methods to handle audio saving and metadata."""
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
@staticmethod
def save_audio(
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type)
)
metadata = {}
if not args.disable_metadata and cls is not None:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(AudioSaveHelper._OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode="w", format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
# TODO: it would be nice to also support V3 and V5, but there doesn't seem to be a way to set the qscale level; the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, "wb") as f:
f.write(output_buffer.getbuffer())
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
return results
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
AudioSaveHelper.save_audio(
audio,
filename_prefix=filename_prefix,
folder_type=FolderType.output,
cls=cls,
format=format,
quality=quality,
)
)
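A short sketch of the audio payload this helper expects; the (batch, channels, samples) waveform layout is inferred from the enumerate loop and the waveform.shape[0] channel check in save_audio, and the tensor is synthetic:
import torch

# Hypothetical input: one batch item, stereo, one second at 44.1 kHz.
audio = {"waveform": torch.zeros(1, 2, 44100), "sample_rate": 44100}
ui = AudioSaveHelper.get_save_audio_ui(audio, filename_prefix="ComfyUI", cls=None)
# With format="opus", the 44100 Hz input would be resampled up to 48000 Hz
# (the smallest supported Opus rate above it) before encoding.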
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
compress_level=1,
)
self.animated = animated
def as_dict(self):
return {
"images": self.values,
"animated": (self.animated,)
}
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, cls, **kwargs)
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
folder_type=FolderType.temp,
cls=cls,
format="flac",
quality="128k",
)
def as_dict(self) -> dict:
return {"audio": self.values}
class PreviewVideo(_UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
return {"images": self.values, "animated": (True,)}
class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
def as_dict(self):
return {"result": [self.model_file, self.camera_info]}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
class _UI:
SavedResult = SavedResult
SavedImages = SavedImages
SavedAudios = SavedAudios
ImageSaveHelper = ImageSaveHelper
AudioSaveHelper = AudioSaveHelper
PreviewImage = PreviewImage
PreviewMask = PreviewMask
PreviewAudio = PreviewAudio
PreviewVideo = PreviewVideo
PreviewUI3D = PreviewUI3D
PreviewText = PreviewText
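As a small illustration of the _UIOutput contract these preview classes share, as_dict() produces the payload the executor serializes and forwards to the frontend; PreviewText is the simplest case:
preview = PreviewText("sampling finished")
preview.as_dict()  # -> {"text": ("sampling finished",)}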

View File

@@ -6,6 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -40,4 +41,5 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"ComfyExtension",
]

File diff suppressed because it is too large

View File

@@ -127,7 +127,7 @@ class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')

View File

@@ -8,10 +8,10 @@ from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
Veo2GenVidRequest,
Veo2GenVidResponse,
Veo2GenVidPollRequest,
Veo2GenVidPollResponse
VeoGenVidRequest,
VeoGenVidResponse,
VeoGenVidPollRequest,
VeoGenVidPollResponse
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@@ -35,7 +35,7 @@ def convert_image_to_base64(image: torch.Tensor):
return tensor_to_base64_string(scaled_image)
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
@@ -130,6 +130,14 @@ class VeoVideoGenerationNode(ComfyNodeABC):
"default": None,
"tooltip": "Optional reference image to guide video generation",
}),
"model": (
IO.COMBO,
{
"options": ["veo-2.0-generate-001"],
"default": "veo-2.0-generate-001",
"tooltip": "Veo 2 model to use for video generation",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
@@ -141,7 +149,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "generate_video"
CATEGORY = "api node/video/Veo"
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
API_NODE = True
def generate_video(
@@ -154,6 +162,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
person_generation="ALLOW",
seed=0,
image=None,
model="veo-2.0-generate-001",
generate_audio=False,
unique_id: Optional[str] = None,
**kwargs,
):
@@ -188,16 +198,19 @@ class VeoVideoGenerationNode(ComfyNodeABC):
parameters["negativePrompt"] = negative_prompt
if seed > 0:
parameters["seed"] = seed
# Only add generateAudio for Veo 3 models
if "veo-3.0" in model:
parameters["generateAudio"] = generate_audio
# Initial request to start video generation
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/veo/generate",
path=f"/proxy/veo/{model}/generate",
method=HttpMethod.POST,
request_model=Veo2GenVidRequest,
response_model=Veo2GenVidResponse
request_model=VeoGenVidRequest,
response_model=VeoGenVidResponse
),
request=Veo2GenVidRequest(
request=VeoGenVidRequest(
instances=instances,
parameters=parameters
),
@@ -223,16 +236,16 @@ class VeoVideoGenerationNode(ComfyNodeABC):
# Define the polling operation
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/veo/poll",
path=f"/proxy/veo/{model}/poll",
method=HttpMethod.POST,
request_model=Veo2GenVidPollRequest,
response_model=Veo2GenVidPollResponse
request_model=VeoGenVidPollRequest,
response_model=VeoGenVidPollResponse
),
completed_statuses=["completed"],
failed_statuses=[], # No failed statuses, we'll handle errors after polling
status_extractor=status_extractor,
progress_extractor=progress_extractor,
request=Veo2GenVidPollRequest(
request=VeoGenVidPollRequest(
operationName=operation_name
),
auth_kwargs=kwargs,
@@ -298,11 +311,64 @@ class VeoVideoGenerationNode(ComfyNodeABC):
return (VideoFromFile(video_io),)
# Register the node
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
"""
Generates videos from text prompts using Google's Veo 3 API.
Supported models:
- veo-3.0-generate-001
- veo-3.0-fast-generate-001
This node extends the base Veo node with Veo 3 specific features including
audio generation and fixed 8-second duration.
"""
@classmethod
def INPUT_TYPES(s):
parent_input = super().INPUT_TYPES()
# Update model options for Veo 3
parent_input["optional"]["model"] = (
IO.COMBO,
{
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
"default": "veo-3.0-generate-001",
"tooltip": "Veo 3 model to use for video generation",
},
)
# Add generateAudio parameter
parent_input["optional"]["generate_audio"] = (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
}
)
# Update duration constraints for Veo 3 (only 8 seconds supported)
parent_input["optional"]["duration_seconds"] = (
IO.INT,
{
"default": 8,
"min": 8,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
},
)
return parent_input
# Register the nodes
NODE_CLASS_MAPPINGS = {
"VeoVideoGenerationNode": VeoVideoGenerationNode,
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
}
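Because Veo3VideoGenerationNode only patches a few entries of the schema it inherits, the overrides can be sanity-checked without a server (a sketch against the code above):
opts = Veo3VideoGenerationNode.INPUT_TYPES()
assert opts["optional"]["model"][1]["options"] == ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"]
assert opts["optional"]["duration_seconds"][1]["min"] == 8  # Veo 3 is fixed at 8 seconds
assert "generate_audio" in opts["optional"]                 # Veo 3 only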

View File

@@ -4,9 +4,12 @@ from typing import Type, Literal
import nodes
import asyncio
import inspect
from comfy_execution.graph_utils import is_link
from comfy_execution.graph_utils import is_link, ExecutionBlocker
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
# NOTE: ExecutionBlocker code was moved to graph_utils.py to prevent torch from being imported too soon during unit tests
ExecutionBlocker = ExecutionBlocker
class DependencyCycleError(Exception):
pass
@@ -294,21 +297,3 @@ class ExecutionList(TopologicalSort):
del blocked_by[node_id]
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
return list(blocked_by.keys())
class ExecutionBlocker:
"""
Return this from a node and any users will be blocked with the given error message.
If the message is None, execution will be blocked silently instead.
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
possible, a lazy input will be more efficient and have a better user experience.
This functionality is useful in two cases:
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
lazy evaluation to let it conditionally disable itself.)
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
"""
def __init__(self, message):
self.message = message

View File

@@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix):
return new_graph, tuple(new_outputs)
class ExecutionBlocker:
"""
Return this from a node and any users will be blocked with the given error message.
If the message is None, execution will be blocked silently instead.
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
possible, a lazy input will be more efficient and have a better user experience.
This functionality is useful in two cases:
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
lazy evaluation to let it conditionally disable itself.)
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
"""
def __init__(self, message):
self.message = message
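A sketch of case 1 from the docstring, using a hypothetical V1-style gate node (not part of this change); as the docstring notes, a lazy input is usually the better choice for your own nodes:
class GateImage:
    # Hypothetical node, for illustration only.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",), "enabled": ("BOOLEAN", {"default": True})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "gate"
    CATEGORY = "_for_testing"

    def gate(self, image, enabled):
        if not enabled:
            # message=None blocks downstream consumers silently, per the docstring.
            return (ExecutionBlocker(None),)
        return (image,)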

View File

@@ -149,6 +149,7 @@ class WanFirstLastFrameToVideo:
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
clip_vision_output = None
if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.47"
__version__ = "0.3.49"

View File

@@ -7,7 +7,7 @@ import threading
import time
import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
import torch
@@ -32,6 +32,8 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io
class ExecutionResult(Enum):
@@ -56,7 +58,15 @@ class IsChangedCache:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if not hasattr(class_def, "IS_CHANGED"):
has_is_changed = False
is_changed_name = None
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
has_is_changed = True
is_changed_name = "fingerprint_inputs"
elif hasattr(class_def, "IS_CHANGED"):
has_is_changed = True
is_changed_name = "IS_CHANGED"
if not has_is_changed:
self.is_changed[node_id] = False
return self.is_changed[node_id]
@@ -65,9 +75,9 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
@@ -126,9 +136,14 @@ class CacheSet:
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
is_v3 = issubclass(class_def, _ComfyNodeInternal)
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
missing_keys = {}
hidden_inputs_v3 = {}
for x in inputs:
input_data = inputs[x]
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
@@ -153,22 +168,37 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
elif input_category is not None:
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys
if is_v3:
if schema.hidden:
if io.Hidden.prompt in schema.hidden:
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
if io.Hidden.dynprompt in schema.hidden:
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
if io.Hidden.extra_pnginfo in schema.hidden:
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
if io.Hidden.unique_id in schema.hidden:
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
if io.Hidden.auth_token_comfy_org in schema.hidden:
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
if io.Hidden.api_key_comfy_org in schema.hidden:
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
else:
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys, hidden_inputs_v3
map_node_over_list = None #Don't hook this please
@@ -184,7 +214,7 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -214,7 +244,22 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if execution_block is None:
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
f = getattr(obj, func)
# V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# if is just a class, then assign no resources or state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
f = make_locked_method_func(type_obj, func, class_clone)
# V1
else:
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
@@ -266,8 +311,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -298,6 +343,26 @@ def get_output_from_returns(return_values, obj):
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
elif isinstance(r, _NodeOutputInternal):
# V3
if r.ui is not None:
if isinstance(r.ui, dict):
uis.append(r.ui)
else:
uis.append(r.ui.as_dict())
if r.expand is not None:
has_subgraph = True
new_graph = r.expand
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif r.result is not None:
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
@@ -381,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -391,8 +456,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
obj = class_def()
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -424,7 +493,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@@ -672,8 +741,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
validate_function_inputs = []
validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"):
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
if issubclass(obj_class, _ComfyNodeInternal):
validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
argspec = inspect.getfullargspec(validate_function)
validate_function_inputs = argspec.args
validate_has_kwargs = argspec.varkw is not None
received_types = {}
@@ -848,7 +923,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@@ -856,8 +931,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
@@ -891,7 +965,7 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
async def validate_prompt(prompt_id, prompt):
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@@ -915,7 +989,8 @@ async def validate_prompt(prompt_id, prompt):
return (False, error, [], {})
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)
if len(outputs) == 0:
error = {

View File

@@ -6,6 +6,7 @@ import os
import sys
import json
import hashlib
import inspect
import traceback
import math
import time
@@ -29,6 +30,7 @@ import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions
from comfy_api.latest import io, ComfyExtension
import comfy.clip_vision
@@ -923,7 +925,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -2152,6 +2154,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if os.path.isdir(web_dir):
EXTENSION_WEB_DIRS[module_name] = web_dir
# V1 node definition
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
if name not in ignore:
@@ -2160,8 +2163,38 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
# V3 Extension Definition
elif hasattr(module, "comfy_entrypoint"):
entrypoint = getattr(module, "comfy_entrypoint")
if not callable(entrypoint):
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
return False
try:
if inspect.iscoroutinefunction(entrypoint):
extension = await entrypoint()
else:
extension = entrypoint()
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
return False
for node_cls in node_list:
node_cls: io.ComfyNode
schema = node_cls.GET_SCHEMA()
if schema.node_id not in ignore:
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
if schema.display_name is not None:
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
return True
except Exception as e:
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
return False
else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
return False
except Exception as e:
logging.warning(traceback.format_exc())
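The loader above implies the minimal shape of a V3 extension module. In the sketch below, comfy_entrypoint, ComfyExtension, get_node_list, and the schema's node_id/display_name are confirmed by this hunk; define_schema, io.Schema, io.Mask.Input/Output, io.NodeOutput, and the execute hook are assumptions about the V3 schema API introduced elsewhere in this PR:
# custom_nodes/my_ext/__init__.py -- hypothetical module layout
from comfy_api.latest import ComfyExtension, io

class InvertMask(io.ComfyNode):
    @classmethod
    def define_schema(cls):                       # assumed V3 schema hook
        return io.Schema(
            node_id="MyExt_InvertMask",           # keys NODE_CLASS_MAPPINGS via schema.node_id
            display_name="Invert Mask (My Ext)",  # keys NODE_DISPLAY_NAME_MAPPINGS
            inputs=[io.Mask.Input("mask")],
            outputs=[io.Mask.Output()],
        )

    @classmethod
    def execute(cls, mask):                       # assumed execution hook
        return io.NodeOutput(1.0 - mask)

class MyExtension(ComfyExtension):
    async def get_node_list(self):
        return [InvertMask]

async def comfy_entrypoint():                     # a plain (sync) callable also works
    return MyExtension()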
@@ -2286,7 +2319,7 @@ async def init_builtin_extra_nodes():
"nodes_string.py",
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
"nodes_tcfg.py"
"nodes_tcfg.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.47"
version = "0.3.49"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.41
comfyui-workflow-templates==0.1.51
comfyui-embedded-docs==0.2.4
torch
torchsde

View File

@@ -30,6 +30,7 @@ from comfy_api import feature_flags
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from comfy_api.internal import _ComfyNodeInternal
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@@ -591,6 +592,8 @@ class PromptServer():
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
if issubclass(obj_class, _ComfyNodeInternal):
return obj_class.GET_NODE_INFO_V1()
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
@@ -681,7 +684,12 @@ class PromptServer():
if "prompt" in json_data:
prompt = json_data["prompt"]
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
valid = await execution.validate_prompt(prompt_id, prompt)
partial_execution_targets = None
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
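From the client side, opting into partial execution is one extra key in the /prompt payload. A sketch with placeholders: workflow is assumed to be a normal API-format graph dict, and "9" the id of the single output node to run:
import json
import urllib.request

payload = {
    "prompt": workflow,                   # placeholder: API-format graph (node id -> class_type/inputs)
    "client_id": client_id,               # placeholder: this client's websocket id
    "partial_execution_targets": ["9"],   # omit this key entirely to run all output nodes
}
req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",
    data=json.dumps(payload).encode("utf-8"),
)
urllib.request.urlopen(req)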

View File

@@ -7,7 +7,7 @@ import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient
from tests.inference.test_execution import ComfyClient, run_warmup
@pytest.mark.execution
@@ -24,6 +24,7 @@ class TestAsyncNodes:
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
if use_lru:
@@ -82,6 +83,9 @@ class TestAsyncNodes:
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that multiple async nodes execute in parallel."""
# Warmup execution to ensure server is fully initialized
run_warmup(client)
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -148,6 +152,9 @@ class TestAsyncNodes:
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with lazy evaluation."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_lazy")
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -305,6 +312,9 @@ class TestAsyncNodes:
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async nodes are properly cached."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_cache")
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
@@ -324,6 +334,9 @@ class TestAsyncNodes:
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes within dynamically generated prompts."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_dynamic")
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

View File

@@ -15,10 +15,18 @@ import urllib.parse
import urllib.error
from comfy_execution.graph_utils import GraphBuilder, Node
def run_warmup(client, prefix="warmup"):
"""Run a simple workflow to warm up the server."""
warmup_g = GraphBuilder(prefix=prefix)
warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
warmup_g.node("PreviewImage", images=warmup_image.out(0))
client.run(warmup_g)
class RunResult:
def __init__(self, prompt_id: str):
self.outputs: Dict[str,Dict] = {}
self.runs: Dict[str,bool] = {}
self.cached: Dict[str,bool] = {}
self.prompt_id: str = prompt_id
def get_output(self, node: Node):
@@ -27,6 +35,13 @@ class RunResult:
def did_run(self, node: Node):
return self.runs.get(node.id, False)
def was_cached(self, node: Node):
return self.cached.get(node.id, False)
def was_executed(self, node: Node):
"""Returns True if node was either run or cached"""
return self.did_run(node) or self.was_cached(node)
def get_images(self, node: Node):
output = self.get_output(node)
if output is None:
@@ -51,8 +66,10 @@ class ComfyClient:
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
self.ws = ws
def queue_prompt(self, prompt):
def queue_prompt(self, prompt, partial_execution_targets=None):
p = {"prompt": prompt, "client_id": self.client_id}
if partial_execution_targets is not None:
p["partial_execution_targets"] = partial_execution_targets
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
@@ -70,13 +87,13 @@ class ComfyClient:
def set_test_name(self, name):
self.test_name = name
def run(self, graph):
def run(self, graph, partial_execution_targets=None):
prompt = graph.finalize()
for node in graph.nodes.values():
if node.class_type == 'SaveImage':
node.inputs['filename_prefix'] = self.test_name
prompt_id = self.queue_prompt(prompt)['prompt_id']
prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
result = RunResult(prompt_id)
while True:
out = self.ws.recv()
@@ -92,7 +109,10 @@ class ComfyClient:
elif message['type'] == 'execution_error':
raise Exception(message['data'])
elif message['type'] == 'execution_cached':
pass # Probably want to store this off for testing
if message['data']['prompt_id'] == prompt_id:
cached_nodes = message['data'].get('nodes', [])
for node_id in cached_nodes:
result.cached[node_id] = True
history = self.get_history(prompt_id)[prompt_id]
for node_id in history['outputs']:
@@ -130,6 +150,7 @@ class TestExecution:
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
if use_lru:
@@ -498,12 +519,15 @@ class TestExecution:
assert not result.did_run(test_node), "The execution should have been cached"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create sleep nodes for each duration
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
# Add outputs to verify the execution
@@ -515,10 +539,9 @@ class TestExecution:
result = client.run(g)
elapsed_time = time.time() - start_time
# The test should take around 0.4 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
# We'll allow for up to 0.8s total to account for overhead
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
# The test should take around 3.0 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
# Verify that all nodes executed
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
@@ -526,6 +549,9 @@ class TestExecution:
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
g = builder
# Create input images with different values
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -537,9 +563,9 @@ class TestExecution:
image1=image1.out(0),
image2=image2.out(0),
image3=image3.out(0),
sleep1=0.4,
sleep2=0.5,
sleep3=0.6)
sleep1=4.8,
sleep2=4.9,
sleep3=5.0)
output = g.node("SaveImage", images=parallel_sleep.out(0))
start_time = time.time()
@@ -548,7 +574,7 @@ class TestExecution:
# Similar to the previous test, expect parallel execution of the sleep nodes
# which should complete in less than the sum of all sleeps
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 10.0s"
# Verify the parallel sleep node executed
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
@@ -585,3 +611,151 @@ class TestExecution:
assert len(images) == 2, "Should have 2 images"
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
# Output nodes included in the partial execution list are executed
def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Create two separate output nodes
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input2.out(0))
# Run with partial execution targeting only output1
result = client.run(g, partial_execution_targets=[output1.id])
assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
assert not result.did_run(input2), "Input2 should not have run"
assert not result.did_run(output2), "Output2 should not have run"
# Verify only output1 produced results
assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
# Output nodes NOT included in the partial execution list are NOT executed
def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
# Create three output nodes
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input2.out(0))
output3 = g.node("SaveImage", images=input3.out(0))
# Run with partial execution targeting only output1 and output3
result = client.run(g, partial_execution_targets=[output1.id, output3.id])
assert result.was_executed(input1), "Input1 should have been executed"
assert result.was_executed(input3), "Input3 should have been executed"
assert result.was_executed(output1), "Output1 should have been executed"
assert result.was_executed(output3), "Output3 should have been executed"
assert not result.did_run(input2), "Input2 should not have run"
assert not result.did_run(output2), "Output2 should not have run"
# Output nodes NOT in list ARE executed if necessary for nodes that are in the list
def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create a processing chain with an OUTPUT_NODE that has socket outputs
output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
# Create another node that depends on the output_with_socket
dependent_node = g.node("TestLazyMixImages",
image1=output_with_socket.out(0),
image2=input1.out(0),
mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
# Create the final output
final_output = g.node("SaveImage", images=dependent_node.out(0))
# Run with partial execution targeting only the final output
result = client.run(g, partial_execution_targets=[final_output.id])
# All nodes should have been executed because they're dependencies
assert result.was_executed(input1), "Input1 should have been executed"
assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
assert result.was_executed(dependent_node), "Dependent node should have been executed"
assert result.was_executed(final_output), "Final output should have been executed"
# Lazy execution works with partial execution
def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
# Create masks that will trigger different lazy execution paths
mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1
mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images
# Create two lazy mix nodes
lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
output1 = g.node("SaveImage", images=lazy_mix1.out(0))
output2 = g.node("SaveImage", images=lazy_mix2.out(0))
# Run with partial execution targeting only output1
result = client.run(g, partial_execution_targets=[output1.id])
# For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
assert result.was_executed(input1), "Input1 should have been executed"
assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
assert result.was_executed(mask1), "Mask1 should have been executed"
assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
assert result.was_executed(output1), "Output1 should have been executed"
# Nothing from output2 path should run
assert not result.did_run(input3), "Input3 should not have run"
assert not result.did_run(mask2), "Mask2 should not have run"
assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
assert not result.did_run(output2), "Output2 should not have run"
# Multiple OUTPUT_NODEs with dependencies
def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Create a chain of OUTPUT_NODEs
output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
# Create regular output nodes
save1 = g.node("SaveImage", images=output_node1.out(0))
save2 = g.node("SaveImage", images=output_node2.out(0))
save3 = g.node("SaveImage", images=input2.out(0))
# Run targeting only save2
result = client.run(g, partial_execution_targets=[save2.id])
# Should run: input1, output_node1, output_node2, save2
assert result.was_executed(input1), "Input1 should have been executed"
assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
assert result.was_executed(save2), "Save2 should have been executed"
# Should NOT run: input2, save1, save3
assert not result.did_run(input2), "Input2 should not have run"
assert not result.did_run(save1), "Save1 should not have run"
assert not result.did_run(save3), "Save3 should not have run"
# Empty partial execution list (should execute nothing)
def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
_output1 = g.node("SaveImage", images=input1.out(0))
# Run with empty partial execution list
try:
_result = client.run(g, partial_execution_targets=[])
# Should get an error because no outputs are selected
assert False, "Should have raised an error for empty partial execution list"
except urllib.error.HTTPError:
pass # Expected behavior

View File

@@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC):
"expand": g.finalize(),
}
class TestOutputNodeWithSocketOutput:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing"
OUTPUT_NODE = True
def process(self, image, value):
# Apply value scaling and return both as output and socket
result = image * value
return (result,)
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestSamplingInExpansion": TestSamplingInExpansion,
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestSamplingInExpansion": "Sampling In Expansion",
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
}