mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 05:00:03 +00:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32a95bba8a | ||
|
|
da1ad9b516 | ||
|
|
d044a24398 | ||
|
|
5be6fd09ff | ||
|
|
f69609bbd6 | ||
|
|
c012400240 | ||
|
|
03895dea7c | ||
|
|
84f9759424 | ||
|
|
7991341e89 | ||
|
|
140ffc7fdc | ||
|
|
182f90b5ec | ||
|
|
aebac22193 | ||
|
|
13aaa66ec2 | ||
|
|
5f582a9757 | ||
|
|
fbcc23945d | ||
|
|
3dfefc88d0 | ||
|
|
bff60b5cfc | ||
|
|
1e638a140b | ||
|
|
4696d74305 | ||
|
|
5ee381c058 | ||
|
|
4887743a2a | ||
|
|
97b8a2c26a | ||
|
|
97eb256a35 | ||
|
|
61b08d4ba6 | ||
|
|
da9dab7edd | ||
|
|
d2aaef029c | ||
|
|
0a3d062e06 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1,2 +1,3 @@
|
||||
/web/assets/** linguist-generated
|
||||
/web/** linguist-vendored
|
||||
comfy_api_nodes/apis/__init__.py linguist-generated
|
||||
|
||||
@@ -111,7 +111,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
|
||||
@@ -130,10 +130,21 @@ class ModelFileManager:
|
||||
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
full_path = os.path.join(dirpath, file_name)
|
||||
relative_path = os.path.relpath(full_path, directory)
|
||||
|
||||
# Get file metadata
|
||||
file_info = {
|
||||
"name": relative_path,
|
||||
"pathIndex": pathIndex,
|
||||
"modified": os.path.getmtime(full_path), # Add modification time
|
||||
"created": os.path.getctime(full_path), # Add creation time
|
||||
"size": os.path.getsize(full_path) # Add file size
|
||||
}
|
||||
result.append(file_info)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
@@ -144,7 +155,7 @@ class ModelFileManager:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
|
||||
return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
|
||||
return result, dirs, time.perf_counter()
|
||||
|
||||
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
||||
@@ -20,13 +20,15 @@ class FileInfo(TypedDict):
|
||||
path: str
|
||||
size: int
|
||||
modified: int
|
||||
created: int
|
||||
|
||||
|
||||
def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path)
|
||||
"modified": os.path.getmtime(path),
|
||||
"created": os.path.getctime(path)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import math
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
@@ -10,12 +11,15 @@ class CONDRegular:
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device, skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
@@ -29,14 +33,14 @@ class CONDRegular:
|
||||
|
||||
|
||||
class CONDNoiseShape(CONDRegular):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
def process_cond(self, batch_size, area, **kwargs):
|
||||
data = self.cond
|
||||
if area is not None:
|
||||
dims = len(area) // 2
|
||||
for i in range(dims):
|
||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
|
||||
|
||||
|
||||
class CONDCrossAttn(CONDRegular):
|
||||
@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device: skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
@@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
@@ -92,10 +99,10 @@ class CONDList(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
def process_cond(self, batch_size, **kwargs):
|
||||
out = []
|
||||
for c in self.cond:
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
|
||||
|
||||
return self._copy_with(out)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import comfy.model_detection
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
import comfy.latent_formats
|
||||
import comfy.model_base
|
||||
|
||||
import comfy.cldm.cldm
|
||||
import comfy.t2i_adapter.adapter
|
||||
@@ -43,7 +44,6 @@ if TYPE_CHECKING:
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
@@ -265,12 +265,12 @@ class ControlNet(ControlBase):
|
||||
for c in self.extra_conds:
|
||||
temp = cond.get(c, None)
|
||||
if temp is not None:
|
||||
extra[c] = temp.to(dtype)
|
||||
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
|
||||
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype=None)
|
||||
|
||||
def copy(self):
|
||||
|
||||
@@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
return x * torch.sigmoid(x)
|
||||
# x * sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
||||
@@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
||||
400
comfy/ldm/qwen_image/model.py
Normal file
400
comfy/ldm/qwen_image/model.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# https://github.com/QwenLM/Qwen-Image (Apache 2.0)
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
from einops import repeat
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
|
||||
self.approximate = approximate
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
def forward(self, timestep, hidden_states):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
dim_head: int = 64,
|
||||
heads: int = 8,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
eps: float = 1e-5,
|
||||
out_bias: bool = True,
|
||||
out_dim: int = None,
|
||||
out_context_dim: int = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||
self.dropout = dropout
|
||||
|
||||
# Q/K normalization
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
# Image stream projections
|
||||
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
# Text stream projections
|
||||
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
# Output projections
|
||||
self.to_out = nn.ModuleList([
|
||||
operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout)
|
||||
])
|
||||
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor, # Image stream
|
||||
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
||||
encoder_hidden_states_mask: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
|
||||
img_query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
txt_query = self.norm_added_q(txt_query)
|
||||
txt_key = self.norm_added_k(txt_key)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
|
||||
img_attn_output = self.to_out[0](img_attn_output)
|
||||
img_attn_output = self.to_out[1](img_attn_output)
|
||||
txt_attn_output = self.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
class QwenImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
eps=eps,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
|
||||
img_attn_output, txt_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine=False, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(conditioning_embedding))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class QwenImageTransformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 16,
|
||||
num_layers: int = 60,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_in = operations.Linear(in_channels, self.inner_dim, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(joint_attention_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
QwenImageTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def pos_embeds(self, x, context):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_start = round(max(h_len, w_len))
|
||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
image_rotary_emb = self.pos_embeds(x, context)
|
||||
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
@@ -769,8 +769,7 @@ class CameraWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if self.control_adapter is not None and camera_conditions is not None:
|
||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
||||
x = x + x_camera
|
||||
x = x + self.control_adapter(camera_conditions).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
@@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
return x + self.shortcut(old_x)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
@@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
@@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -106,10 +107,12 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
def convert_tensor(extra, dtype, device):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra = comfy.model_management.cast_to_device(extra, device, dtype)
|
||||
else:
|
||||
extra = comfy.model_management.cast_to_device(extra, device, None)
|
||||
return extra
|
||||
|
||||
|
||||
@@ -160,7 +163,7 @@ class BaseModel(torch.nn.Module):
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
@@ -169,20 +172,21 @@ class BaseModel(torch.nn.Module):
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
xc = xc.to(dtype)
|
||||
device = xc.device
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
if context is not None:
|
||||
context = context.to(dtype)
|
||||
context = comfy.model_management.cast_to_device(context, device, dtype)
|
||||
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
extra = convert_tensor(extra, dtype)
|
||||
extra = convert_tensor(extra, dtype, device)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
ex.append(convert_tensor(ext, dtype, device))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
@@ -398,7 +402,7 @@ class SD21UNCLIP(BaseModel):
|
||||
unclip_conditioning = kwargs.get("unclip_conditioning", None)
|
||||
device = kwargs["device"]
|
||||
if unclip_conditioning is None:
|
||||
return torch.zeros((1, self.adm_channels))
|
||||
return torch.zeros((1, self.adm_channels), device=device)
|
||||
else:
|
||||
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
|
||||
|
||||
@@ -612,9 +616,11 @@ class IP2P:
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
else:
|
||||
image = image.to(device=device)
|
||||
|
||||
if image.shape[1:] != noise.shape[1:]:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return self.process_ip2p_image_in(image)
|
||||
@@ -693,7 +699,7 @@ class StableCascade_B(BaseModel):
|
||||
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
|
||||
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
|
||||
|
||||
out["effnet"] = comfy.conds.CONDRegular(prior)
|
||||
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
return out
|
||||
|
||||
@@ -1158,10 +1164,10 @@ class WAN21_Vace(WAN21):
|
||||
|
||||
vace_frames_out = []
|
||||
for j in range(len(vace_frames)):
|
||||
vf = vace_frames[j].clone()
|
||||
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
|
||||
for i in range(0, vf.shape[1], 16):
|
||||
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
|
||||
vf = torch.cat([vf, mask[j]], dim=1)
|
||||
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
|
||||
vace_frames_out.append(vf)
|
||||
|
||||
vace_frames = torch.stack(vace_frames_out, dim=1)
|
||||
@@ -1303,3 +1309,14 @@ class Omnigen2(BaseModel):
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class QwenImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
@@ -481,6 +481,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@@ -867,7 +872,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
|
||||
@@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
|
||||
|
||||
hooks = conds.get('hooks', None)
|
||||
control = conds.get('control', None)
|
||||
|
||||
12
comfy/sd.py
12
comfy/sd.py
@@ -47,6 +47,7 @@ import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -771,6 +772,7 @@ class CLIPType(Enum):
|
||||
CHROMA = 15
|
||||
ACE = 16
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@@ -791,6 +793,7 @@ class TEModel(Enum):
|
||||
T5_XXL_OLD = 8
|
||||
GEMMA_2_2B = 9
|
||||
QWEN25_3B = 10
|
||||
QWEN25_7B = 11
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@@ -812,7 +815,11 @@ def detect_te_model(sd):
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
return TEModel.QWEN25_3B
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
if weight.shape[0] == 256:
|
||||
return TEModel.QWEN25_3B
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
@@ -917,6 +924,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
|
||||
elif te_model == TEModel.QWEN25_7B:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
|
||||
@@ -19,6 +19,7 @@ import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -1229,7 +1230,36 @@ class Omnigen2(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
class QwenImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "qwen_image",
|
||||
}
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.15,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.8 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.QwenImage(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -43,6 +43,23 @@ class Qwen25_3BConfig:
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
vocab_size: int = 152064
|
||||
hidden_size: int = 3584
|
||||
intermediate_size: int = 18944
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 28
|
||||
num_key_value_heads: int = 4
|
||||
max_position_embeddings: int = 128000
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
vocab_size: int = 256000
|
||||
@@ -348,6 +365,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_7BVLI_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
71
comfy/text_encoders/qwen_image.py
Normal file
71
comfy/text_encoders/qwen_image.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.llama
|
||||
import os
|
||||
import torch
|
||||
import numbers
|
||||
|
||||
class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class QwenImageTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
tok_pairs = token_weight_pairs["qwen25_7b"][0]
|
||||
count_im_start = 0
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 151644 and count_im_start < 2:
|
||||
template_end = i
|
||||
count_im_start += 1
|
||||
|
||||
if out.shape[1] > (template_end + 3):
|
||||
if tok_pairs[template_end + 1][0] == 872:
|
||||
if tok_pairs[template_end + 2][0] == 198:
|
||||
template_end += 3
|
||||
|
||||
out = out[:, template_end:]
|
||||
|
||||
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||
extra.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class QwenImageTEModel_(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return QwenImageTEModel_
|
||||
86
comfy_api/generate_api_stubs.py
Normal file
86
comfy_api/generate_api_stubs.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to generate .pyi stub files for the synchronous API wrappers.
|
||||
This allows generating stubs without running the full ComfyUI application.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
# Add ComfyUI to path so we can import modules
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
|
||||
from comfy_api.version_list import supported_versions
|
||||
|
||||
|
||||
def generate_stubs_for_module(module_name: str) -> None:
|
||||
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Check if module has ComfyAPISync (the sync wrapper)
|
||||
if hasattr(module, "ComfyAPISync"):
|
||||
# Module already has a sync class
|
||||
api_class = getattr(module, "ComfyAPI", None)
|
||||
sync_class = getattr(module, "ComfyAPISync")
|
||||
|
||||
if api_class:
|
||||
# Generate the stub file
|
||||
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||
logging.info(f"Generated stub file for {module_name}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
|
||||
)
|
||||
|
||||
elif hasattr(module, "ComfyAPI"):
|
||||
# Module only has async API, need to create sync wrapper first
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
|
||||
api_class = getattr(module, "ComfyAPI")
|
||||
sync_class = create_sync_class(api_class)
|
||||
|
||||
# Generate the stub file
|
||||
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||
logging.info(f"Generated stub file for {module_name}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to generate stub for {module_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to generate all API stub files."""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
logging.info("Starting stub generation...")
|
||||
|
||||
# Dynamically get module names from supported_versions
|
||||
api_modules = []
|
||||
for api_class in supported_versions:
|
||||
# Extract module name from the class
|
||||
module_name = api_class.__module__
|
||||
if module_name not in api_modules:
|
||||
api_modules.append(module_name)
|
||||
|
||||
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
|
||||
|
||||
# Generate stubs for each module
|
||||
for module_name in api_modules:
|
||||
generate_stubs_for_module(module_name)
|
||||
|
||||
logging.info("Stub generation complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,8 +1,16 @@
|
||||
from .basic_types import ImageInput, AudioInput
|
||||
from .video_types import VideoInput
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input import (
|
||||
ImageInput,
|
||||
AudioInput,
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
import torch
|
||||
from typing import TypedDict
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
class AudioInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing audio input.
|
||||
"""
|
||||
|
||||
waveform: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
sample_rate: int
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input.basic_types import (
|
||||
ImageInput,
|
||||
AudioInput,
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
]
|
||||
|
||||
@@ -1,85 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input.video_types import VideoInput
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_components(self) -> VideoComponents:
|
||||
"""
|
||||
Abstract method to get the video components (images, audio, and frame rate).
|
||||
|
||||
Returns:
|
||||
VideoComponents containing images, audio, and frame rate
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
loading the entire video into memory.
|
||||
|
||||
Returns:
|
||||
Either a file path (str) or a BytesIO object that can be opened with av.
|
||||
|
||||
Default implementation creates a BytesIO buffer, but subclasses should
|
||||
override this for better performance when possible.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
self.save_to(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
components = self.get_components()
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
# Default implementation - subclasses should override for better performance
|
||||
source = self.get_stream_source()
|
||||
with av.open(source, mode="r") as container:
|
||||
return container.format.name
|
||||
__all__ = [
|
||||
"VideoInput",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
]
|
||||
|
||||
@@ -1,324 +1,2 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import AudioInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Add our new metadata
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
output_container.metadata[key] = value
|
||||
else:
|
||||
output_container.metadata[key] = json.dumps(value)
|
||||
|
||||
# Add streams to the new container
|
||||
stream_map = {}
|
||||
for stream in streams:
|
||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||
stream_map[stream] = out_stream
|
||||
|
||||
# Write packets to the new container
|
||||
for packet in container.demux():
|
||||
if packet.stream in stream_map and packet.dts is not None:
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._input_impl.video_types import * # noqa: F403
|
||||
|
||||
150
comfy_api/internal/__init__.py
Normal file
150
comfy_api/internal/__init__.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Internal infrastructure for ComfyAPI
|
||||
from .api_registry import (
|
||||
ComfyAPIBase as ComfyAPIBase,
|
||||
ComfyAPIWithVersion as ComfyAPIWithVersion,
|
||||
register_versions as register_versions,
|
||||
get_all_versions as get_all_versions,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
from dataclasses import asdict
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||
|
||||
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||
"""
|
||||
if base is None:
|
||||
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||
base = cls.GET_BASE_CLASS()
|
||||
base_attr = getattr(base, name, None)
|
||||
if base_attr is None:
|
||||
return None
|
||||
base_func = base_attr.__func__
|
||||
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
|
||||
if c is base: # reached the placeholder – we're done
|
||||
break
|
||||
if name in c.__dict__: # first class that *defines* the attr
|
||||
func = getattr(c, name).__func__
|
||||
if func is not base_func: # real override
|
||||
return getattr(cls, name) # bound to *cls*
|
||||
return None
|
||||
|
||||
|
||||
class _ComfyNodeInternal:
|
||||
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V1(cls):
|
||||
...
|
||||
|
||||
|
||||
class _NodeOutputInternal:
|
||||
"""Class that all V3-based APIs inherit from for NodeOutput.
|
||||
|
||||
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||
...
|
||||
|
||||
|
||||
def as_pruned_dict(dataclass_obj):
|
||||
'''Return dict of dataclass object with pruned None values.'''
|
||||
return prune_dict(asdict(dataclass_obj))
|
||||
|
||||
def prune_dict(d: dict):
|
||||
return {k: v for k,v in d.items() if v is not None}
|
||||
|
||||
|
||||
def is_class(obj):
|
||||
'''
|
||||
Returns True if is a class type.
|
||||
Returns False if is a class instance.
|
||||
'''
|
||||
return isinstance(obj, type)
|
||||
|
||||
|
||||
def copy_class(cls: type) -> type:
|
||||
'''
|
||||
Copy a class and its attributes.
|
||||
'''
|
||||
if cls is None:
|
||||
return None
|
||||
cls_dict = {
|
||||
k: v for k, v in cls.__dict__.items()
|
||||
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
|
||||
}
|
||||
# new class
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(cls,),
|
||||
cls_dict
|
||||
)
|
||||
# metadata preservation
|
||||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
return new_cls
|
||||
|
||||
|
||||
class classproperty(object):
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
def __get__(self, obj, owner):
|
||||
return self.f(owner)
|
||||
|
||||
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def shallow_clone_class(cls, new_name=None):
|
||||
'''
|
||||
Shallow clone a class while preserving super() functionality.
|
||||
'''
|
||||
new_name = new_name or f"{cls.__name__}Clone"
|
||||
# Include the original class in the bases to maintain proper inheritance
|
||||
new_bases = (cls,) + cls.__bases__
|
||||
return type(new_name, new_bases, dict(cls.__dict__))
|
||||
|
||||
# NOTE: this was ai generated and validated by hand
|
||||
def lock_class(cls):
|
||||
'''
|
||||
Lock a class so that its top-levelattributes cannot be modified.
|
||||
'''
|
||||
# Locked instance __setattr__
|
||||
def locked_instance_setattr(self, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
|
||||
)
|
||||
# Locked metaclass
|
||||
class LockedMeta(type(cls)):
|
||||
def __setattr__(cls_, name, value):
|
||||
raise AttributeError(
|
||||
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
|
||||
)
|
||||
# Rebuild class with locked behavior
|
||||
locked_dict = dict(cls.__dict__)
|
||||
locked_dict['__setattr__'] = locked_instance_setattr
|
||||
|
||||
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||
|
||||
|
||||
def make_locked_method_func(type_obj, func, class_clone):
|
||||
"""
|
||||
Returns a function that, when called with **inputs, will execute:
|
||||
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||
|
||||
Supports both synchronous and asynchronous methods.
|
||||
"""
|
||||
locked_class = lock_class(class_clone)
|
||||
method = getattr(type_obj, func).__func__
|
||||
|
||||
# Check if the original method is async
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
async def wrapped_async_func(**inputs):
|
||||
return await method(locked_class, **inputs)
|
||||
return wrapped_async_func
|
||||
else:
|
||||
def wrapped_func(**inputs):
|
||||
return method(locked_class, **inputs)
|
||||
return wrapped_func
|
||||
39
comfy_api/internal/api_registry.py
Normal file
39
comfy_api/internal/api_registry.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Type, List, NamedTuple
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from packaging import version as packaging_version
|
||||
|
||||
|
||||
class ComfyAPIBase(ProxiedSingleton):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class ComfyAPIWithVersion(NamedTuple):
|
||||
version: str
|
||||
api_class: Type[ComfyAPIBase]
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> packaging_version.Version:
|
||||
"""
|
||||
Parses a version string into a packaging_version.Version object.
|
||||
Raises ValueError if the version string is invalid.
|
||||
"""
|
||||
if version_str == "latest":
|
||||
return packaging_version.parse("9999999.9999999.9999999")
|
||||
return packaging_version.parse(version_str)
|
||||
|
||||
|
||||
registered_versions: List[ComfyAPIWithVersion] = []
|
||||
|
||||
|
||||
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||
versions.sort(key=lambda x: parse_version(x.version))
|
||||
global registered_versions
|
||||
registered_versions = versions
|
||||
|
||||
|
||||
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||
"""
|
||||
Returns a list of all registered ComfyAPI versions.
|
||||
"""
|
||||
return registered_versions
|
||||
987
comfy_api/internal/async_to_sync.py
Normal file
987
comfy_api/internal/async_to_sync.py
Normal file
@@ -0,0 +1,987 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
"""Tracks types discovered during stub generation for automatic import generation."""
|
||||
|
||||
def __init__(self):
|
||||
self.discovered_types = {} # type_name -> (module, qualname)
|
||||
self.builtin_types = {
|
||||
"Any",
|
||||
"Dict",
|
||||
"List",
|
||||
"Optional",
|
||||
"Tuple",
|
||||
"Union",
|
||||
"Set",
|
||||
"Sequence",
|
||||
"cast",
|
||||
"NamedTuple",
|
||||
"str",
|
||||
"int",
|
||||
"float",
|
||||
"bool",
|
||||
"None",
|
||||
"bytes",
|
||||
"object",
|
||||
"type",
|
||||
"dict",
|
||||
"list",
|
||||
"tuple",
|
||||
"set",
|
||||
}
|
||||
self.already_imported = (
|
||||
set()
|
||||
) # Track types already imported to avoid duplicates
|
||||
|
||||
def track_type(self, annotation):
|
||||
"""Track a type annotation and record its module/import info."""
|
||||
if annotation is None or annotation is type(None):
|
||||
return
|
||||
|
||||
# Skip builtins and typing module types we already import
|
||||
type_name = getattr(annotation, "__name__", None)
|
||||
if type_name and (
|
||||
type_name in self.builtin_types or type_name in self.already_imported
|
||||
):
|
||||
return
|
||||
|
||||
# Get module and qualname
|
||||
module = getattr(annotation, "__module__", None)
|
||||
qualname = getattr(annotation, "__qualname__", type_name or "")
|
||||
|
||||
# Skip types from typing module (they're already imported)
|
||||
if module == "typing":
|
||||
return
|
||||
|
||||
# Skip UnionType and GenericAlias from types module as they're handled specially
|
||||
if module == "types" and type_name in ("UnionType", "GenericAlias"):
|
||||
return
|
||||
|
||||
if module and module not in ["builtins", "__main__"]:
|
||||
# Store the type info
|
||||
if type_name:
|
||||
self.discovered_types[type_name] = (module, qualname)
|
||||
|
||||
def get_imports(self, main_module_name: str) -> list[str]:
|
||||
"""Generate import statements for all discovered types."""
|
||||
imports = []
|
||||
imports_by_module = {}
|
||||
|
||||
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
|
||||
# Skip types from the main module (they're already imported)
|
||||
if main_module_name and module == main_module_name:
|
||||
continue
|
||||
|
||||
if module not in imports_by_module:
|
||||
imports_by_module[module] = []
|
||||
if type_name not in imports_by_module[module]: # Avoid duplicates
|
||||
imports_by_module[module].append(type_name)
|
||||
|
||||
# Generate import statements
|
||||
for module, types in sorted(imports_by_module.items()):
|
||||
if len(types) == 1:
|
||||
imports.append(f"from {module} import {types[0]}")
|
||||
else:
|
||||
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
|
||||
|
||||
return imports
|
||||
|
||||
|
||||
class AsyncToSyncConverter:
|
||||
"""
|
||||
Provides utilities to convert async classes to sync classes with proper type hints.
|
||||
"""
|
||||
|
||||
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
||||
_thread_pool_lock = threading.Lock()
|
||||
_thread_pool_initialized = False
|
||||
|
||||
@classmethod
|
||||
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
|
||||
"""Get or create the shared thread pool with proper thread-safe initialization."""
|
||||
# Fast path - check if already initialized without acquiring lock
|
||||
if cls._thread_pool_initialized:
|
||||
assert cls._thread_pool is not None, "Thread pool should be initialized"
|
||||
return cls._thread_pool
|
||||
|
||||
# Slow path - acquire lock and create pool if needed
|
||||
with cls._thread_pool_lock:
|
||||
if not cls._thread_pool_initialized:
|
||||
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers, thread_name_prefix="async_to_sync_"
|
||||
)
|
||||
cls._thread_pool_initialized = True
|
||||
|
||||
# This should never be None at this point, but add assertion for type checker
|
||||
assert cls._thread_pool is not None
|
||||
return cls._thread_pool
|
||||
|
||||
@classmethod
|
||||
def run_async_in_thread(cls, coro_func, *args, **kwargs):
|
||||
"""
|
||||
Run an async function in a separate thread from the thread pool.
|
||||
Blocks until the async function completes.
|
||||
Properly propagates contextvars between threads and manages event loops.
|
||||
"""
|
||||
# Capture current context - this includes all context variables
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# Store the result and any exception that occurs
|
||||
result_container: dict = {"result": None, "exception": None}
|
||||
|
||||
# Function that runs in the thread pool
|
||||
def run_in_thread():
|
||||
# Create new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Create the coroutine within the context
|
||||
async def run_with_context():
|
||||
# The coroutine function might access context variables
|
||||
return await coro_func(*args, **kwargs)
|
||||
|
||||
# Run the coroutine with the captured context
|
||||
# This ensures all context variables are available in the async function
|
||||
result = context.run(loop.run_until_complete, run_with_context())
|
||||
result_container["result"] = result
|
||||
except Exception as e:
|
||||
# Store the exception to re-raise in the calling thread
|
||||
result_container["exception"] = e
|
||||
finally:
|
||||
# Ensure event loop is properly closed to prevent warnings
|
||||
try:
|
||||
# Cancel any remaining tasks
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Run the loop briefly to handle cancellations
|
||||
if pending:
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Close the event loop
|
||||
loop.close()
|
||||
|
||||
# Clear the event loop from the thread
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
# Submit to thread pool and wait for result
|
||||
thread_pool = cls.get_thread_pool()
|
||||
future = thread_pool.submit(run_in_thread)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Re-raise any exception that occurred in the thread
|
||||
if result_container["exception"] is not None:
|
||||
raise result_container["exception"]
|
||||
|
||||
return result_container["result"]
|
||||
|
||||
@classmethod
|
||||
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a new class with synchronous versions of all async methods.
|
||||
|
||||
Args:
|
||||
async_class: The async class to convert
|
||||
thread_pool_size: Size of thread pool to use
|
||||
|
||||
Returns:
|
||||
A new class with sync versions of all async methods
|
||||
"""
|
||||
sync_class_name = "ComfyAPISyncStub"
|
||||
cls.get_thread_pool(thread_pool_size)
|
||||
|
||||
# Create a proper class with docstrings and proper base classes
|
||||
sync_class_dict = {
|
||||
"__doc__": async_class.__doc__,
|
||||
"__module__": async_class.__module__,
|
||||
"__qualname__": sync_class_name,
|
||||
"__orig_class__": async_class, # Store original class for typing references
|
||||
}
|
||||
|
||||
# Create __init__ method
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._async_instance = async_class(*args, **kwargs)
|
||||
|
||||
# Handle annotated class attributes (like execution: Execution)
|
||||
# Get all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
# For each annotated attribute, check if it needs to be created or wrapped
|
||||
for attr_name, attr_type in all_annotations.items():
|
||||
if hasattr(self._async_instance, attr_name):
|
||||
# Attribute exists on the instance
|
||||
attr = getattr(self._async_instance, attr_name)
|
||||
# Check if this attribute needs a sync wrapper
|
||||
if hasattr(attr, "__class__"):
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
if isinstance(attr, ProxiedSingleton):
|
||||
# Create a sync version of this attribute
|
||||
try:
|
||||
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||
# Create instance of the sync wrapper with the async instance
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = attr
|
||||
setattr(self, attr_name, sync_attr)
|
||||
except Exception:
|
||||
# If we can't create a sync version, keep the original
|
||||
setattr(self, attr_name, attr)
|
||||
else:
|
||||
# Not async, just copy the reference
|
||||
setattr(self, attr_name, attr)
|
||||
else:
|
||||
# Attribute doesn't exist, but is annotated - create it
|
||||
# This handles cases like execution: Execution
|
||||
if isinstance(attr_type, type):
|
||||
# Check if the type is defined as an inner class
|
||||
if hasattr(async_class, attr_type.__name__):
|
||||
inner_class = getattr(async_class, attr_type.__name__)
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
# Create an instance of the inner class
|
||||
try:
|
||||
# For ProxiedSingleton classes, get or create the singleton instance
|
||||
if issubclass(inner_class, ProxiedSingleton):
|
||||
async_instance = inner_class.get_instance()
|
||||
else:
|
||||
async_instance = inner_class()
|
||||
|
||||
# Create sync wrapper
|
||||
sync_attr_class = cls.create_sync_class(inner_class)
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = async_instance
|
||||
setattr(self, attr_name, sync_attr)
|
||||
# Also set on the async instance for consistency
|
||||
setattr(self._async_instance, attr_name, async_instance)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Failed to create instance for {attr_name}: {e}"
|
||||
)
|
||||
|
||||
# Handle other instance attributes that might not be annotated
|
||||
for name, attr in inspect.getmembers(self._async_instance):
|
||||
if name.startswith("_") or hasattr(self, name):
|
||||
continue
|
||||
|
||||
# If attribute is an instance of a class, and that class is defined in the original class
|
||||
# we need to check if it needs a sync wrapper
|
||||
if isinstance(attr, object) and not isinstance(
|
||||
attr, (str, int, float, bool, list, dict, tuple)
|
||||
):
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
|
||||
if isinstance(attr, ProxiedSingleton):
|
||||
# Create a sync version of this nested class
|
||||
try:
|
||||
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||
# Create instance of the sync wrapper with the async instance
|
||||
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||
sync_attr._async_instance = attr
|
||||
setattr(self, name, sync_attr)
|
||||
except Exception:
|
||||
# If we can't create a sync version, keep the original
|
||||
setattr(self, name, attr)
|
||||
|
||||
sync_class_dict["__init__"] = __init__
|
||||
|
||||
# Process methods from the async class
|
||||
for name, method in inspect.getmembers(
|
||||
async_class, predicate=inspect.isfunction
|
||||
):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Extract the actual return type from a coroutine
|
||||
if inspect.iscoroutinefunction(method):
|
||||
# Create sync version of async method with proper signature
|
||||
@functools.wraps(method)
|
||||
def sync_method(self, *args, _method_name=name, **kwargs):
|
||||
async_method = getattr(self._async_instance, _method_name)
|
||||
return AsyncToSyncConverter.run_async_in_thread(
|
||||
async_method, *args, **kwargs
|
||||
)
|
||||
|
||||
# Add to the class dict
|
||||
sync_class_dict[name] = sync_method
|
||||
else:
|
||||
# For regular methods, create a proxy method
|
||||
@functools.wraps(method)
|
||||
def proxy_method(self, *args, _method_name=name, **kwargs):
|
||||
method = getattr(self._async_instance, _method_name)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
# Add to the class dict
|
||||
sync_class_dict[name] = proxy_method
|
||||
|
||||
# Handle property access
|
||||
for name, prop in inspect.getmembers(
|
||||
async_class, lambda x: isinstance(x, property)
|
||||
):
|
||||
|
||||
def make_property(name, prop_obj):
|
||||
def getter(self):
|
||||
value = getattr(self._async_instance, name)
|
||||
if inspect.iscoroutinefunction(value):
|
||||
|
||||
def sync_fn(*args, **kwargs):
|
||||
return AsyncToSyncConverter.run_async_in_thread(
|
||||
value, *args, **kwargs
|
||||
)
|
||||
|
||||
return sync_fn
|
||||
return value
|
||||
|
||||
def setter(self, value):
|
||||
setattr(self._async_instance, name, value)
|
||||
|
||||
return property(getter, setter if prop_obj.fset else None)
|
||||
|
||||
sync_class_dict[name] = make_property(name, prop)
|
||||
|
||||
# Create the class
|
||||
sync_class = type(sync_class_name, (object,), sync_class_dict)
|
||||
|
||||
return sync_class
|
||||
|
||||
@classmethod
|
||||
def _format_type_annotation(
|
||||
cls, annotation, type_tracker: Optional[TypeTracker] = None
|
||||
) -> str:
|
||||
"""Convert a type annotation to its string representation for stub files."""
|
||||
if (
|
||||
annotation is inspect.Parameter.empty
|
||||
or annotation is inspect.Signature.empty
|
||||
):
|
||||
return "Any"
|
||||
|
||||
# Handle None type
|
||||
if annotation is type(None):
|
||||
return "None"
|
||||
|
||||
# Track the type if we have a tracker
|
||||
if type_tracker:
|
||||
type_tracker.track_type(annotation)
|
||||
|
||||
# Try using typing.get_origin/get_args for Python 3.8+
|
||||
try:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin is not None:
|
||||
# Track the origin type
|
||||
if type_tracker:
|
||||
type_tracker.track_type(origin)
|
||||
|
||||
# Get the origin name
|
||||
origin_name = getattr(origin, "__name__", str(origin))
|
||||
if "." in origin_name:
|
||||
origin_name = origin_name.split(".")[-1]
|
||||
|
||||
# Special handling for types.UnionType (Python 3.10+ pipe operator)
|
||||
# Convert to old-style Union for compatibility
|
||||
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
|
||||
origin_name = "Union"
|
||||
|
||||
# Format arguments recursively
|
||||
if args:
|
||||
formatted_args = []
|
||||
for arg in args:
|
||||
# Track each type in the union
|
||||
if type_tracker:
|
||||
type_tracker.track_type(arg)
|
||||
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
|
||||
return f"{origin_name}[{', '.join(formatted_args)}]"
|
||||
else:
|
||||
return origin_name
|
||||
except (AttributeError, TypeError):
|
||||
# Fallback for older Python versions or non-generic types
|
||||
pass
|
||||
|
||||
# Handle generic types the old way for compatibility
|
||||
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
|
||||
origin = annotation.__origin__
|
||||
origin_name = (
|
||||
origin.__name__
|
||||
if hasattr(origin, "__name__")
|
||||
else str(origin).split("'")[1]
|
||||
)
|
||||
|
||||
# Format each type argument
|
||||
args = []
|
||||
for arg in annotation.__args__:
|
||||
args.append(cls._format_type_annotation(arg, type_tracker))
|
||||
|
||||
return f"{origin_name}[{', '.join(args)}]"
|
||||
|
||||
# Handle regular types with __name__
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
|
||||
# Handle special module types (like types from typing module)
|
||||
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
|
||||
# For types like typing.Literal, typing.TypedDict, etc.
|
||||
return annotation.__qualname__
|
||||
|
||||
# Last resort: string conversion with cleanup
|
||||
type_str = str(annotation)
|
||||
|
||||
# Clean up common patterns more robustly
|
||||
if type_str.startswith("<class '") and type_str.endswith("'>"):
|
||||
type_str = type_str[8:-2] # Remove "<class '" and "'>"
|
||||
|
||||
# Remove module prefixes for common modules
|
||||
for prefix in ["typing.", "builtins.", "types."]:
|
||||
if type_str.startswith(prefix):
|
||||
type_str = type_str[len(prefix) :]
|
||||
|
||||
# Handle special cases
|
||||
if type_str in ("_empty", "inspect._empty"):
|
||||
return "None"
|
||||
|
||||
# Fix NoneType (this should rarely be needed now)
|
||||
if type_str == "NoneType":
|
||||
return "None"
|
||||
|
||||
return type_str
|
||||
|
||||
@classmethod
|
||||
def _extract_coroutine_return_type(cls, annotation):
|
||||
"""Extract the actual return type from a Coroutine annotation."""
|
||||
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
|
||||
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
|
||||
return annotation.__args__[2]
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def _format_parameter_default(cls, default_value) -> str:
|
||||
"""Format a parameter's default value for stub files."""
|
||||
if default_value is inspect.Parameter.empty:
|
||||
return ""
|
||||
elif default_value is None:
|
||||
return " = None"
|
||||
elif isinstance(default_value, bool):
|
||||
return f" = {default_value}"
|
||||
elif default_value == {}:
|
||||
return " = {}"
|
||||
elif default_value == []:
|
||||
return " = []"
|
||||
else:
|
||||
return f" = {default_value}"
|
||||
|
||||
@classmethod
|
||||
def _format_method_parameters(
|
||||
cls,
|
||||
sig: inspect.Signature,
|
||||
skip_self: bool = True,
|
||||
type_hints: Optional[dict] = None,
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> str:
|
||||
"""Format method parameters for stub files."""
|
||||
params = []
|
||||
if type_hints is None:
|
||||
type_hints = {}
|
||||
|
||||
for i, (param_name, param) in enumerate(sig.parameters.items()):
|
||||
if i == 0 and param_name == "self" and skip_self:
|
||||
params.append("self")
|
||||
else:
|
||||
# Get type annotation from type hints if available, otherwise from signature
|
||||
annotation = type_hints.get(param_name, param.annotation)
|
||||
type_str = cls._format_type_annotation(annotation, type_tracker)
|
||||
|
||||
# Get default value
|
||||
default_str = cls._format_parameter_default(param.default)
|
||||
|
||||
# Combine parameter parts
|
||||
if annotation is inspect.Parameter.empty:
|
||||
params.append(f"{param_name}: Any{default_str}")
|
||||
else:
|
||||
params.append(f"{param_name}: {type_str}{default_str}")
|
||||
|
||||
return ", ".join(params)
|
||||
|
||||
@classmethod
|
||||
def _generate_method_signature(
|
||||
cls,
|
||||
method_name: str,
|
||||
method,
|
||||
is_async: bool = False,
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> str:
|
||||
"""Generate a complete method signature for stub files."""
|
||||
sig = inspect.signature(method)
|
||||
|
||||
# Try to get evaluated type hints to resolve string annotations
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
type_hints = get_type_hints(method)
|
||||
except Exception:
|
||||
# Fallback to empty dict if we can't get type hints
|
||||
type_hints = {}
|
||||
|
||||
# For async methods, extract the actual return type
|
||||
return_annotation = type_hints.get('return', sig.return_annotation)
|
||||
if is_async and inspect.iscoroutinefunction(method):
|
||||
return_annotation = cls._extract_coroutine_return_type(return_annotation)
|
||||
|
||||
# Format parameters with type hints
|
||||
params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
|
||||
|
||||
# Format return type
|
||||
return_type = cls._format_type_annotation(return_annotation, type_tracker)
|
||||
if return_annotation is inspect.Signature.empty:
|
||||
return_type = "None"
|
||||
|
||||
return f"def {method_name}({params_str}) -> {return_type}: ..."
|
||||
|
||||
@classmethod
|
||||
def _generate_imports(
|
||||
cls, async_class: Type, type_tracker: TypeTracker
|
||||
) -> list[str]:
|
||||
"""Generate import statements for the stub file."""
|
||||
imports = []
|
||||
|
||||
# Add standard typing imports
|
||||
imports.append(
|
||||
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
|
||||
)
|
||||
|
||||
# Add imports from the original module
|
||||
if async_class.__module__ != "builtins":
|
||||
module = inspect.getmodule(async_class)
|
||||
additional_types = []
|
||||
|
||||
if module:
|
||||
# Check if module has __all__ defined
|
||||
module_all = getattr(module, "__all__", None)
|
||||
|
||||
for name, obj in sorted(inspect.getmembers(module)):
|
||||
if isinstance(obj, type):
|
||||
# Skip if __all__ is defined and this name isn't in it
|
||||
# unless it's already been tracked as used in type annotations
|
||||
if module_all is not None and name not in module_all:
|
||||
# Check if this type was actually used in annotations
|
||||
if name not in type_tracker.discovered_types:
|
||||
continue
|
||||
|
||||
# Check for NamedTuple
|
||||
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
|
||||
additional_types.append(name)
|
||||
# Mark as already imported
|
||||
type_tracker.already_imported.add(name)
|
||||
# Check for Enum
|
||||
elif issubclass(obj, Enum) and name != "Enum":
|
||||
additional_types.append(name)
|
||||
# Mark as already imported
|
||||
type_tracker.already_imported.add(name)
|
||||
|
||||
if additional_types:
|
||||
type_imports = ", ".join([async_class.__name__] + additional_types)
|
||||
imports.append(f"from {async_class.__module__} import {type_imports}")
|
||||
else:
|
||||
imports.append(
|
||||
f"from {async_class.__module__} import {async_class.__name__}"
|
||||
)
|
||||
|
||||
# Add imports for all discovered types
|
||||
# Pass the main module name to avoid duplicate imports
|
||||
imports.extend(
|
||||
type_tracker.get_imports(main_module_name=async_class.__module__)
|
||||
)
|
||||
|
||||
# Add base module import if needed
|
||||
if hasattr(inspect.getmodule(async_class), "__name__"):
|
||||
module_name = inspect.getmodule(async_class).__name__
|
||||
if "." in module_name:
|
||||
base_module = module_name.split(".")[0]
|
||||
# Only add if not already importing from it
|
||||
if not any(imp.startswith(f"from {base_module}") for imp in imports):
|
||||
imports.append(f"import {base_module}")
|
||||
|
||||
return imports
|
||||
|
||||
@classmethod
|
||||
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
# Look for class attributes that are classes
|
||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||
if isinstance(attr, type) and not name.startswith("_"):
|
||||
class_attributes.append((name, attr))
|
||||
elif (
|
||||
hasattr(async_class, "__annotations__")
|
||||
and name in async_class.__annotations__
|
||||
):
|
||||
annotation = async_class.__annotations__[name]
|
||||
if isinstance(annotation, type):
|
||||
class_attributes.append((name, annotation))
|
||||
|
||||
return class_attributes
|
||||
|
||||
@classmethod
|
||||
def _generate_inner_class_stub(
|
||||
cls,
|
||||
name: str,
|
||||
attr: Type,
|
||||
indent: str = " ",
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> list[str]:
|
||||
"""Generate stub for an inner class."""
|
||||
stub_lines = []
|
||||
stub_lines.append(f"{indent}class {name}Sync:")
|
||||
|
||||
# Add docstring if available
|
||||
if hasattr(attr, "__doc__") and attr.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
|
||||
)
|
||||
|
||||
# Add __init__ if it exists
|
||||
if hasattr(attr, "__init__"):
|
||||
try:
|
||||
init_method = getattr(attr, "__init__")
|
||||
init_sig = inspect.signature(init_method)
|
||||
|
||||
# Try to get type hints
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
init_hints = get_type_hints(init_method)
|
||||
except Exception:
|
||||
init_hints = {}
|
||||
|
||||
# Format parameters
|
||||
params_str = cls._format_method_parameters(
|
||||
init_sig, type_hints=init_hints, type_tracker=type_tracker
|
||||
)
|
||||
# Add __init__ docstring if available (before the method)
|
||||
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(
|
||||
init_method.__doc__, f"{indent} "
|
||||
)
|
||||
)
|
||||
stub_lines.append(
|
||||
f"{indent} def __init__({params_str}) -> None: ..."
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
stub_lines.append(
|
||||
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
|
||||
)
|
||||
|
||||
# Add methods to the inner class
|
||||
has_methods = False
|
||||
for method_name, method in sorted(
|
||||
inspect.getmembers(attr, predicate=inspect.isfunction)
|
||||
):
|
||||
if method_name.startswith("_"):
|
||||
continue
|
||||
|
||||
has_methods = True
|
||||
try:
|
||||
# Add method docstring if available (before the method signature)
|
||||
if method.__doc__:
|
||||
stub_lines.extend(
|
||||
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
|
||||
)
|
||||
|
||||
method_sig = cls._generate_method_signature(
|
||||
method_name, method, is_async=True, type_tracker=type_tracker
|
||||
)
|
||||
stub_lines.append(f"{indent} {method_sig}")
|
||||
except (ValueError, TypeError):
|
||||
stub_lines.append(
|
||||
f"{indent} def {method_name}(self, *args, **kwargs): ..."
|
||||
)
|
||||
|
||||
if not has_methods:
|
||||
stub_lines.append(f"{indent} pass")
|
||||
|
||||
return stub_lines
|
||||
|
||||
@classmethod
|
||||
def _format_docstring_for_stub(
|
||||
cls, docstring: str, indent: str = " "
|
||||
) -> list[str]:
|
||||
"""Format a docstring for inclusion in a stub file with proper indentation."""
|
||||
if not docstring:
|
||||
return []
|
||||
|
||||
# First, dedent the docstring to remove any existing indentation
|
||||
dedented = textwrap.dedent(docstring).strip()
|
||||
|
||||
# Split into lines
|
||||
lines = dedented.split("\n")
|
||||
|
||||
# Build the properly indented docstring
|
||||
result = []
|
||||
result.append(f'{indent}"""')
|
||||
|
||||
for line in lines:
|
||||
if line.strip(): # Non-empty line
|
||||
result.append(f"{indent}{line}")
|
||||
else: # Empty line
|
||||
result.append("")
|
||||
|
||||
result.append(f'{indent}"""')
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
|
||||
"""Post-process stub content to fix any remaining issues."""
|
||||
processed = []
|
||||
|
||||
for line in stub_content:
|
||||
# Skip processing imports
|
||||
if line.startswith(("from ", "import ")):
|
||||
processed.append(line)
|
||||
continue
|
||||
|
||||
# Fix method signatures missing return types
|
||||
if (
|
||||
line.strip().startswith("def ")
|
||||
and line.strip().endswith(": ...")
|
||||
and ") -> " not in line
|
||||
):
|
||||
# Add -> None for methods without return annotation
|
||||
line = line.replace(": ...", " -> None: ...")
|
||||
|
||||
processed.append(line)
|
||||
|
||||
return processed
|
||||
|
||||
@classmethod
|
||||
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||
"""
|
||||
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||
"""
|
||||
try:
|
||||
# Only generate stub if we can determine module path
|
||||
if async_class.__module__ == "__main__":
|
||||
return
|
||||
|
||||
module = inspect.getmodule(async_class)
|
||||
if not module:
|
||||
return
|
||||
|
||||
module_path = module.__file__
|
||||
if not module_path:
|
||||
return
|
||||
|
||||
# Create stub file path in a 'generated' subdirectory
|
||||
module_dir = os.path.dirname(module_path)
|
||||
stub_dir = os.path.join(module_dir, "generated")
|
||||
|
||||
# Ensure the generated directory exists
|
||||
os.makedirs(stub_dir, exist_ok=True)
|
||||
|
||||
module_name = os.path.basename(module_path)
|
||||
if module_name.endswith(".py"):
|
||||
module_name = module_name[:-3]
|
||||
|
||||
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
|
||||
|
||||
# Create a type tracker for this stub generation
|
||||
type_tracker = TypeTracker()
|
||||
|
||||
stub_content = []
|
||||
|
||||
# We'll generate imports after processing all methods to capture all types
|
||||
# Leave a placeholder for imports
|
||||
imports_placeholder_index = len(stub_content)
|
||||
stub_content.append("") # Will be replaced with imports later
|
||||
|
||||
# Class definition
|
||||
stub_content.append(f"class {sync_class.__name__}:")
|
||||
|
||||
# Docstring
|
||||
if async_class.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(async_class.__doc__, " ")
|
||||
)
|
||||
|
||||
# Generate __init__
|
||||
try:
|
||||
init_method = async_class.__init__
|
||||
init_signature = inspect.signature(init_method)
|
||||
|
||||
# Try to get type hints for __init__
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
init_hints = get_type_hints(init_method)
|
||||
except Exception:
|
||||
init_hints = {}
|
||||
|
||||
# Format parameters
|
||||
params_str = cls._format_method_parameters(
|
||||
init_signature, type_hints=init_hints, type_tracker=type_tracker
|
||||
)
|
||||
# Add __init__ docstring if available (before the method)
|
||||
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(init_method.__doc__, " ")
|
||||
)
|
||||
stub_content.append(f" def __init__({params_str}) -> None: ...")
|
||||
except (ValueError, TypeError):
|
||||
stub_content.append(
|
||||
" def __init__(self, *args, **kwargs) -> None: ..."
|
||||
)
|
||||
|
||||
stub_content.append("") # Add newline after __init__
|
||||
|
||||
# Get class attributes
|
||||
class_attributes = cls._get_class_attributes(async_class)
|
||||
|
||||
# Generate inner classes
|
||||
for name, attr in class_attributes:
|
||||
inner_class_stub = cls._generate_inner_class_stub(
|
||||
name, attr, type_tracker=type_tracker
|
||||
)
|
||||
stub_content.extend(inner_class_stub)
|
||||
stub_content.append("") # Add newline after the inner class
|
||||
|
||||
# Add methods to the main class
|
||||
processed_methods = set() # Keep track of methods we've processed
|
||||
for name, method in sorted(
|
||||
inspect.getmembers(async_class, predicate=inspect.isfunction)
|
||||
):
|
||||
if name.startswith("_") or name in processed_methods:
|
||||
continue
|
||||
|
||||
processed_methods.add(name)
|
||||
|
||||
try:
|
||||
method_sig = cls._generate_method_signature(
|
||||
name, method, is_async=True, type_tracker=type_tracker
|
||||
)
|
||||
|
||||
# Add docstring if available (before the method signature for proper formatting)
|
||||
if method.__doc__:
|
||||
stub_content.extend(
|
||||
cls._format_docstring_for_stub(method.__doc__, " ")
|
||||
)
|
||||
|
||||
stub_content.append(f" {method_sig}")
|
||||
|
||||
stub_content.append("") # Add newline after each method
|
||||
|
||||
except (ValueError, TypeError):
|
||||
# If we can't get the signature, just add a simple stub
|
||||
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
|
||||
stub_content.append("") # Add newline
|
||||
|
||||
# Add properties
|
||||
for name, prop in sorted(
|
||||
inspect.getmembers(async_class, lambda x: isinstance(x, property))
|
||||
):
|
||||
stub_content.append(" @property")
|
||||
stub_content.append(f" def {name}(self) -> Any: ...")
|
||||
if prop.fset:
|
||||
stub_content.append(f" @{name}.setter")
|
||||
stub_content.append(
|
||||
f" def {name}(self, value: Any) -> None: ..."
|
||||
)
|
||||
stub_content.append("") # Add newline after each property
|
||||
|
||||
# Add placeholders for the nested class instances
|
||||
# Check the actual attribute names from class annotations and attributes
|
||||
attribute_mappings = {}
|
||||
|
||||
# First check annotations for typed attributes (including from parent classes)
|
||||
# Collect all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||
for class_name, class_type in class_attributes:
|
||||
# If the class type matches the annotated type
|
||||
if (
|
||||
attr_type == class_type
|
||||
or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
|
||||
or (isinstance(attr_type, str) and attr_type == class_name)
|
||||
):
|
||||
attribute_mappings[class_name] = attr_name
|
||||
|
||||
# Remove the extra checking - annotations should be sufficient
|
||||
|
||||
# Add the attribute declarations with proper names
|
||||
for class_name, class_type in class_attributes:
|
||||
# Check if there's a mapping from annotation
|
||||
attr_name = attribute_mappings.get(class_name, class_name)
|
||||
# Use the annotation name if it exists, even if the attribute doesn't exist yet
|
||||
# This is because the attribute might be created at runtime
|
||||
stub_content.append(f" {attr_name}: {class_name}Sync")
|
||||
|
||||
stub_content.append("") # Add a final newline
|
||||
|
||||
# Now generate imports with all discovered types
|
||||
imports = cls._generate_imports(async_class, type_tracker)
|
||||
|
||||
# Deduplicate imports while preserving order
|
||||
seen = set()
|
||||
unique_imports = []
|
||||
for imp in imports:
|
||||
if imp not in seen:
|
||||
seen.add(imp)
|
||||
unique_imports.append(imp)
|
||||
else:
|
||||
logging.warning(f"Duplicate import detected: {imp}")
|
||||
|
||||
# Replace the placeholder with actual imports
|
||||
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
|
||||
unique_imports
|
||||
)
|
||||
|
||||
# Post-process stub content
|
||||
stub_content = cls._post_process_stub_content(stub_content)
|
||||
|
||||
# Write stub file
|
||||
with open(sync_stub_path, "w") as f:
|
||||
f.write("\n".join(stub_content))
|
||||
|
||||
logging.info(f"Generated stub file: {sync_stub_path}")
|
||||
|
||||
except Exception as e:
|
||||
# If stub generation fails, log the error but don't break the main functionality
|
||||
logging.error(
|
||||
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||
"""
|
||||
Creates a sync version of an async class
|
||||
|
||||
Args:
|
||||
async_class: The async class to convert
|
||||
thread_pool_size: Size of thread pool to use
|
||||
|
||||
Returns:
|
||||
A new class with sync versions of all async methods
|
||||
"""
|
||||
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)
|
||||
33
comfy_api/internal/singleton.py
Normal file
33
comfy_api/internal/singleton.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Type, TypeVar
|
||||
|
||||
class SingletonMetaclass(type):
|
||||
T = TypeVar("T", bound="SingletonMetaclass")
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cls._instances[cls]
|
||||
|
||||
def inject_instance(cls: Type[T], instance: T) -> None:
|
||||
assert cls not in SingletonMetaclass._instances, (
|
||||
"Cannot inject instance after first instantiation"
|
||||
)
|
||||
SingletonMetaclass._instances[cls] = instance
|
||||
|
||||
def get_instance(cls: Type[T], *args, **kwargs) -> T:
|
||||
"""
|
||||
Gets the singleton instance of the class, creating it if it doesn't exist.
|
||||
"""
|
||||
if cls not in SingletonMetaclass._instances:
|
||||
SingletonMetaclass._instances[cls] = super(
|
||||
SingletonMetaclass, cls
|
||||
).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class ProxiedSingleton(object, metaclass=SingletonMetaclass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
124
comfy_api/latest/__init__.py
Normal file
124
comfy_api/latest/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
from comfy_api.latest._io import _IO as io #noqa: F401
|
||||
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
from comfy.cli_args import args
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ComfyAPI_latest(ComfyAPIBase):
|
||||
VERSION = "latest"
|
||||
STABLE = False
|
||||
|
||||
class Execution(ProxiedSingleton):
|
||||
async def set_progress(
|
||||
self,
|
||||
value: float,
|
||||
max_value: float,
|
||||
node_id: str | None = None,
|
||||
preview_image: Image.Image | ImageInput | None = None,
|
||||
ignore_size_limit: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
executing_context = get_executing_context()
|
||||
if node_id is None and executing_context is not None:
|
||||
node_id = executing_context.node_id
|
||||
if node_id is None:
|
||||
raise ValueError("node_id must be provided if not in executing context")
|
||||
|
||||
# Convert preview_image to PreviewImageTuple if needed
|
||||
to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image
|
||||
if to_display is not None:
|
||||
# First convert to PIL Image if needed
|
||||
if isinstance(to_display, ImageInput):
|
||||
# Convert ImageInput (torch.Tensor) to PIL Image
|
||||
# Handle tensor shape [B, H, W, C] -> get first image if batch
|
||||
tensor = to_display
|
||||
if len(tensor.shape) == 4:
|
||||
tensor = tensor[0]
|
||||
|
||||
# Convert to numpy array and scale to 0-255
|
||||
image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
|
||||
to_display = Image.fromarray(image_np)
|
||||
|
||||
if isinstance(to_display, Image.Image):
|
||||
# Detect image format from PIL Image
|
||||
image_format = to_display.format if to_display.format else "JPEG"
|
||||
# Use None for preview_size if ignore_size_limit is True
|
||||
preview_size = None if ignore_size_limit else args.preview_size
|
||||
to_display = (image_format, to_display, preview_size)
|
||||
|
||||
get_progress_state().update_progress(
|
||||
node_id=node_id,
|
||||
value=value,
|
||||
max_value=max_value,
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
execution: Execution
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
Called when an extension is loaded.
|
||||
This should be used to initialize any global resources neeeded by the extension.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
"""
|
||||
Returns a list of nodes that this extension provides.
|
||||
"""
|
||||
|
||||
class Input:
|
||||
Image = ImageInput
|
||||
Audio = AudioInput
|
||||
Mask = MaskInput
|
||||
Latent = LatentInput
|
||||
Video = VideoInput
|
||||
|
||||
class InputImpl:
|
||||
VideoFromFile = VideoFromFile
|
||||
VideoFromComponents = VideoFromComponents
|
||||
|
||||
class Types:
|
||||
VideoCodec = VideoCodec
|
||||
VideoContainer = VideoContainer
|
||||
VideoComponents = VideoComponents
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
# Create a synchronous version of the API
|
||||
if TYPE_CHECKING:
|
||||
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
"ComfyAPISync",
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
]
|
||||
10
comfy_api/latest/_input/__init__.py
Normal file
10
comfy_api/latest/_input/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
]
|
||||
42
comfy_api/latest/_input/basic_types.py
Normal file
42
comfy_api/latest/_input/basic_types.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from typing import TypedDict, List, Optional
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
MaskInput = torch.Tensor
|
||||
"""
|
||||
A mask in format [B, H, W] where B is the batch size
|
||||
"""
|
||||
|
||||
class AudioInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing audio input.
|
||||
"""
|
||||
|
||||
waveform: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
sample_rate: int
|
||||
|
||||
class LatentInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing latent input.
|
||||
"""
|
||||
|
||||
samples: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
|
||||
H is the height, and W is the width.
|
||||
"""
|
||||
|
||||
noise_mask: Optional[MaskInput]
|
||||
"""
|
||||
Optional noise mask tensor in the same format as samples.
|
||||
"""
|
||||
|
||||
batch_index: Optional[List[int]]
|
||||
85
comfy_api/latest/_input/video_types.py
Normal file
85
comfy_api/latest/_input/video_types.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_components(self) -> VideoComponents:
|
||||
"""
|
||||
Abstract method to get the video components (images, audio, and frame rate).
|
||||
|
||||
Returns:
|
||||
VideoComponents containing images, audio, and frame rate
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
loading the entire video into memory.
|
||||
|
||||
Returns:
|
||||
Either a file path (str) or a BytesIO object that can be opened with av.
|
||||
|
||||
Default implementation creates a BytesIO buffer, but subclasses should
|
||||
override this for better performance when possible.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
self.save_to(buffer)
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
components = self.get_components()
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
# Default implementation - subclasses should override for better performance
|
||||
source = self.get_stream_source()
|
||||
with av.open(source, mode="r") as container:
|
||||
return container.format.name
|
||||
7
comfy_api/latest/_input_impl/__init__.py
Normal file
7
comfy_api/latest/_input_impl/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
]
|
||||
324
comfy_api/latest/_input_impl/video_types.py
Normal file
324
comfy_api/latest/_input_impl/video_types.py
Normal file
@@ -0,0 +1,324 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Add our new metadata
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
output_container.metadata[key] = value
|
||||
else:
|
||||
output_container.metadata[key] = json.dumps(value)
|
||||
|
||||
# Add streams to the new container
|
||||
stream_map = {}
|
||||
for stream in streams:
|
||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||
stream_map[stream] = out_stream
|
||||
|
||||
# Write packets to the new container
|
||||
for packet in container.demux():
|
||||
if packet.stream in stream_map and packet.dts is not None:
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
1618
comfy_api/latest/_io.py
Normal file
1618
comfy_api/latest/_io.py
Normal file
File diff suppressed because it is too large
Load Diff
72
comfy_api/latest/_resources.py
Normal file
72
comfy_api/latest/_resources.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
class ResourceKey(ABC):
|
||||
Type = Any
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
class TorchDictFolderFilename(ResourceKey):
|
||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||
Type = dict[str, torch.Tensor]
|
||||
def __init__(self, folder_name: str, file_name: str):
|
||||
self.folder_name = folder_name
|
||||
self.file_name = file_name
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.folder_name, self.file_name))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TorchDictFolderFilename):
|
||||
return False
|
||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.folder_name} -> {self.file_name}"
|
||||
|
||||
class Resources(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
pass
|
||||
|
||||
class ResourcesLocal(Resources):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.local_resources: dict[ResourceKey, Any] = {}
|
||||
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
cached = self.local_resources.get(key, None)
|
||||
if cached is not None:
|
||||
logging.info(f"Using cached resource '{key}'")
|
||||
return cached
|
||||
logging.info(f"Loading resource '{key}'")
|
||||
to_return = None
|
||||
if isinstance(key, TorchDictFolderFilename):
|
||||
if default is ...:
|
||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||
else:
|
||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||
if full_path is not None:
|
||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||
|
||||
if to_return is not None:
|
||||
self.local_resources[key] = to_return
|
||||
return to_return
|
||||
if default is not ...:
|
||||
return default
|
||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||
|
||||
|
||||
class _RESOURCES:
|
||||
ResourceKey = ResourceKey
|
||||
TorchDictFolderFilename = TorchDictFolderFilename
|
||||
Resources = Resources
|
||||
ResourcesLocal = ResourcesLocal
|
||||
457
comfy_api/latest/_ui.py
Normal file
457
comfy_api/latest/_ui.py
Normal file
@@ -0,0 +1,457 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self["filename"]
|
||||
|
||||
@property
|
||||
def subfolder(self) -> str:
|
||||
return self["subfolder"]
|
||||
|
||||
@property
|
||||
def type(self) -> FolderType:
|
||||
return FolderType(self["type"])
|
||||
|
||||
|
||||
class SavedImages(_UIOutput):
|
||||
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
self.is_animated = is_animated
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
data = {"images": self.results}
|
||||
if self.is_animated:
|
||||
data["animated"] = (True,)
|
||||
return data
|
||||
|
||||
|
||||
class SavedAudios(_UIOutput):
|
||||
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||
def __init__(self, results: list[SavedResult]):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.results}
|
||||
|
||||
|
||||
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||
if folder_type == FolderType.input:
|
||||
return folder_paths.get_input_directory()
|
||||
if folder_type == FolderType.output:
|
||||
return folder_paths.get_output_directory()
|
||||
return folder_paths.get_temp_directory()
|
||||
|
||||
|
||||
class ImageSaveHelper:
|
||||
"""A helper class with static methods to handle image saving and metadata."""
|
||||
|
||||
@staticmethod
|
||||
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
|
||||
"""Converts a single torch tensor to a PIL Image."""
|
||||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
|
||||
@staticmethod
|
||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if cls.hidden.prompt:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
"prompt".encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
if cls.hidden.extra_pnginfo:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata.add(
|
||||
b"comf",
|
||||
x.encode("latin-1", "strict")
|
||||
+ b"\0"
|
||||
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
|
||||
after_idat=True,
|
||||
)
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||
"""Creates EXIF metadata bytes for WebP images."""
|
||||
exif_data = pil_image.getexif()
|
||||
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||
return exif_data
|
||||
if cls.hidden.prompt is not None:
|
||||
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
inital_exif_tag = 0x010F # EXIF 0x010f = Make
|
||||
for key, value in cls.hidden.extra_pnginfo.items():
|
||||
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
|
||||
inital_exif_tag -= 1
|
||||
return exif_data
|
||||
|
||||
@staticmethod
|
||||
def save_images(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||
) -> list[SavedResult]:
|
||||
"""Saves a batch of images as individual PNG files."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
results = []
|
||||
metadata = ImageSaveHelper._create_png_metadata(cls)
|
||||
for batch_number, image_tensor in enumerate(images):
|
||||
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
"""Saves a batch of images and returns a UI object for the node output."""
|
||||
return SavedImages(
|
||||
ImageSaveHelper.save_images(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_png(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated PNG."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
save_path = os.path.join(full_output_folder, file)
|
||||
pil_images[0].save(
|
||||
save_path,
|
||||
pnginfo=metadata,
|
||||
compress_level=compress_level,
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_png_ui(
|
||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedImages:
|
||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_png(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
@staticmethod
|
||||
def save_animated_webp(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated WebP."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||
)
|
||||
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[0].save(
|
||||
os.path.join(full_output_folder, file),
|
||||
save_all=True,
|
||||
duration=int(1000.0 / fps),
|
||||
append_images=pil_images[1:],
|
||||
exif=pil_exif,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedResult(file, subfolder, folder_type)
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_webp_ui(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
cls: Type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
method: int,
|
||||
) -> SavedImages:
|
||||
"""Saves an animated WebP and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_webp(
|
||||
images,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=method,
|
||||
)
|
||||
return SavedImages([result], is_animated=len(images) > 1)
|
||||
|
||||
|
||||
class AudioSaveHelper:
|
||||
"""A helper class with static methods to handle audio saving and metadata."""
|
||||
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
@staticmethod
|
||||
def save_audio(
|
||||
audio: dict,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
format: str = "flac",
|
||||
quality: str = "128k",
|
||||
) -> list[SavedResult]:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, _get_directory_by_folder_type(folder_type)
|
||||
)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata and cls is not None:
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
results = []
|
||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(AudioSaveHelper._OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format=format)
|
||||
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: # format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append(SavedResult(file, subfolder, folder_type))
|
||||
counter += 1
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_audio_ui(
|
||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
) -> SavedAudios:
|
||||
"""Save and instantly wrap for UI."""
|
||||
return SavedAudios(
|
||||
AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix=filename_prefix,
|
||||
folder_type=FolderType.output,
|
||||
cls=cls,
|
||||
format=format,
|
||||
quality=quality,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
compress_level=1,
|
||||
)
|
||||
self.animated = animated
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"images": self.values,
|
||||
"animated": (self.animated,)
|
||||
}
|
||||
|
||||
|
||||
class PreviewMask(PreviewImage):
|
||||
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
|
||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
super().__init__(preview, animated, cls, **kwargs)
|
||||
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
folder_type=FolderType.temp,
|
||||
cls=cls,
|
||||
format="flac",
|
||||
quality="128k",
|
||||
)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {"audio": self.values}
|
||||
|
||||
|
||||
class PreviewVideo(_UIOutput):
|
||||
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||
self.values = values
|
||||
|
||||
def as_dict(self):
|
||||
return {"images": self.values, "animated": (True,)}
|
||||
|
||||
|
||||
class PreviewUI3D(_UIOutput):
|
||||
def __init__(self, model_file, camera_info, **kwargs):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info]}
|
||||
|
||||
|
||||
class PreviewText(_UIOutput):
|
||||
def __init__(self, value: str, **kwargs):
|
||||
self.value = value
|
||||
|
||||
def as_dict(self):
|
||||
return {"text": (self.value,)}
|
||||
|
||||
|
||||
class _UI:
|
||||
SavedResult = SavedResult
|
||||
SavedImages = SavedImages
|
||||
SavedAudios = SavedAudios
|
||||
ImageSaveHelper = ImageSaveHelper
|
||||
AudioSaveHelper = AudioSaveHelper
|
||||
PreviewImage = PreviewImage
|
||||
PreviewMask = PreviewMask
|
||||
PreviewAudio = PreviewAudio
|
||||
PreviewVideo = PreviewVideo
|
||||
PreviewUI3D = PreviewUI3D
|
||||
PreviewText = PreviewText
|
||||
8
comfy_api/latest/_util/__init__.py
Normal file
8
comfy_api/latest/_util/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
]
|
||||
52
comfy_api/latest/_util/video_types.py
Normal file
52
comfy_api/latest/_util/video_types.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
H264 = "h264"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of codec names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
class VideoContainer(str, Enum):
|
||||
AUTO = "auto"
|
||||
MP4 = "mp4"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of container names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, value) -> str:
|
||||
"""
|
||||
Returns the file extension for the container.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
value = cls(value)
|
||||
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||
return "mp4"
|
||||
return ""
|
||||
|
||||
@dataclass
|
||||
class VideoComponents:
|
||||
"""
|
||||
Dataclass representing the components of a video.
|
||||
"""
|
||||
|
||||
images: ImageInput
|
||||
frame_rate: Fraction
|
||||
audio: Optional[AudioInput] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
20
comfy_api/latest/generated/ComfyAPISyncStub.pyi
Normal file
20
comfy_api/latest/generated/ComfyAPISyncStub.pyi
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
class ComfyAPISyncStub:
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class ExecutionSync:
|
||||
def __init__(self) -> None: ...
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||
|
||||
execution: ExecutionSync
|
||||
8
comfy_api/util.py
Normal file
8
comfy_api/util.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
"VideoCodec",
|
||||
"VideoContainer",
|
||||
"VideoComponents",
|
||||
]
|
||||
@@ -1,7 +1,7 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
|
||||
@@ -1,51 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
H264 = "h264"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of codec names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
class VideoContainer(str, Enum):
|
||||
AUTO = "auto"
|
||||
MP4 = "mp4"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of container names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, value) -> str:
|
||||
"""
|
||||
Returns the file extension for the container.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
value = cls(value)
|
||||
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||
return "mp4"
|
||||
return ""
|
||||
|
||||
@dataclass
|
||||
class VideoComponents:
|
||||
"""
|
||||
Dataclass representing the components of a video.
|
||||
"""
|
||||
|
||||
images: ImageInput
|
||||
frame_rate: Fraction
|
||||
audio: Optional[AudioInput] = None
|
||||
metadata: Optional[dict] = None
|
||||
# This file only exists for backwards compatibility.
|
||||
from comfy_api.latest._util.video_types import (
|
||||
VideoContainer,
|
||||
VideoCodec,
|
||||
VideoComponents,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
]
|
||||
|
||||
42
comfy_api/v0_0_1/__init__.py
Normal file
42
comfy_api/v0_0_1/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from comfy_api.v0_0_2 import (
|
||||
ComfyAPIAdapter_v0_0_2,
|
||||
Input as Input_v0_0_2,
|
||||
InputImpl as InputImpl_v0_0_2,
|
||||
Types as Types_v0_0_2,
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
|
||||
|
||||
# This version only exists to serve as a template for future version adapters.
|
||||
# There is no reason anyone should ever use it.
|
||||
class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2):
|
||||
VERSION = "0.0.1"
|
||||
STABLE = True
|
||||
|
||||
class Input(Input_v0_0_2):
|
||||
pass
|
||||
|
||||
class InputImpl(InputImpl_v0_0_2):
|
||||
pass
|
||||
|
||||
class Types(Types_v0_0_2):
|
||||
pass
|
||||
|
||||
ComfyAPI = ComfyAPIAdapter_v0_0_1
|
||||
|
||||
# Create a synchronous version of the API
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[ComfyAPISyncStub]
|
||||
|
||||
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
"ComfyAPISync",
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
]
|
||||
20
comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi
Normal file
20
comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
class ComfyAPISyncStub:
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class ExecutionSync:
|
||||
def __init__(self) -> None: ...
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||
|
||||
execution: ExecutionSync
|
||||
45
comfy_api/v0_0_2/__init__.py
Normal file
45
comfy_api/v0_0_2/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from comfy_api.latest import (
|
||||
ComfyAPI_latest,
|
||||
Input as Input_latest,
|
||||
InputImpl as InputImpl_latest,
|
||||
Types as Types_latest,
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
VERSION = "0.0.2"
|
||||
STABLE = False
|
||||
|
||||
|
||||
class Input(Input_latest):
|
||||
pass
|
||||
|
||||
|
||||
class InputImpl(InputImpl_latest):
|
||||
pass
|
||||
|
||||
|
||||
class Types(Types_latest):
|
||||
pass
|
||||
|
||||
|
||||
ComfyAPI = ComfyAPIAdapter_v0_0_2
|
||||
|
||||
# Create a synchronous version of the API
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)
|
||||
|
||||
__all__ = [
|
||||
"ComfyAPI",
|
||||
"ComfyAPISync",
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
]
|
||||
20
comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi
Normal file
20
comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
class ComfyAPISyncStub:
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class ExecutionSync:
|
||||
def __init__(self) -> None: ...
|
||||
"""
|
||||
Update the progress bar displayed in the ComfyUI interface.
|
||||
|
||||
This function allows custom nodes and API calls to report their progress
|
||||
back to the user interface, providing visual feedback during long operations.
|
||||
|
||||
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||
"""
|
||||
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||
|
||||
execution: ExecutionSync
|
||||
12
comfy_api/version_list.py
Normal file
12
comfy_api/version_list.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from comfy_api.latest import ComfyAPI_latest
|
||||
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from typing import List, Type
|
||||
|
||||
supported_versions: List[Type[ComfyAPIBase]] = [
|
||||
ComfyAPI_latest,
|
||||
ComfyAPIAdapter_v0_0_2,
|
||||
ComfyAPIAdapter_v0_0_1,
|
||||
]
|
||||
|
||||
2656
comfy_api_nodes/apis/__init__.py
generated
2656
comfy_api_nodes/apis/__init__.py
generated
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import datetime
|
||||
import json
|
||||
|
||||
@@ -127,7 +127,7 @@ class TripoTextToModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
|
||||
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
|
||||
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
|
||||
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
|
||||
model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
@@ -8,10 +8,10 @@ from typing import Optional
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
Veo2GenVidRequest,
|
||||
Veo2GenVidResponse,
|
||||
Veo2GenVidPollRequest,
|
||||
Veo2GenVidPollResponse
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -35,7 +35,7 @@ def convert_image_to_base64(image: torch.Tensor):
|
||||
return tensor_to_base64_string(scaled_image)
|
||||
|
||||
|
||||
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
|
||||
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
@@ -130,6 +130,14 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image to guide video generation",
|
||||
}),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-2.0-generate-001"],
|
||||
"default": "veo-2.0-generate-001",
|
||||
"tooltip": "Veo 2 model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
@@ -141,7 +149,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/Veo"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
|
||||
API_NODE = True
|
||||
|
||||
def generate_video(
|
||||
@@ -154,6 +162,8 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
person_generation="ALLOW",
|
||||
seed=0,
|
||||
image=None,
|
||||
model="veo-2.0-generate-001",
|
||||
generate_audio=False,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -188,16 +198,19 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
parameters["negativePrompt"] = negative_prompt
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
# Only add generateAudio for Veo 3 models
|
||||
if "veo-3.0" in model:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/veo/generate",
|
||||
path=f"/proxy/veo/{model}/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Veo2GenVidRequest,
|
||||
response_model=Veo2GenVidResponse
|
||||
request_model=VeoGenVidRequest,
|
||||
response_model=VeoGenVidResponse
|
||||
),
|
||||
request=Veo2GenVidRequest(
|
||||
request=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
@@ -223,16 +236,16 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
# Define the polling operation
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/veo/poll",
|
||||
path=f"/proxy/veo/{model}/poll",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Veo2GenVidPollRequest,
|
||||
response_model=Veo2GenVidPollResponse
|
||||
request_model=VeoGenVidPollRequest,
|
||||
response_model=VeoGenVidPollResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
||||
status_extractor=status_extractor,
|
||||
progress_extractor=progress_extractor,
|
||||
request=Veo2GenVidPollRequest(
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
@@ -298,11 +311,64 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
return (VideoFromFile(video_io),)
|
||||
|
||||
|
||||
# Register the node
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo 3 API.
|
||||
|
||||
Supported models:
|
||||
- veo-3.0-generate-001
|
||||
- veo-3.0-fast-generate-001
|
||||
|
||||
This node extends the base Veo node with Veo 3 specific features including
|
||||
audio generation and fixed 8-second duration.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
parent_input = super().INPUT_TYPES()
|
||||
|
||||
# Update model options for Veo 3
|
||||
parent_input["optional"]["model"] = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
"default": "veo-3.0-generate-001",
|
||||
"tooltip": "Veo 3 model to use for video generation",
|
||||
},
|
||||
)
|
||||
|
||||
# Add generateAudio parameter
|
||||
parent_input["optional"]["generate_audio"] = (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
|
||||
}
|
||||
)
|
||||
|
||||
# Update duration constraints for Veo 3 (only 8 seconds supported)
|
||||
parent_input["optional"]["duration_seconds"] = (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 8,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
},
|
||||
)
|
||||
|
||||
return parent_input
|
||||
|
||||
|
||||
# Register the nodes
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
||||
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
|
||||
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
|
||||
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
|
||||
}
|
||||
|
||||
@@ -4,9 +4,12 @@ from typing import Type, Literal
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
from comfy_execution.graph_utils import is_link
|
||||
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
|
||||
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||
ExecutionBlocker = ExecutionBlocker
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
pass
|
||||
|
||||
@@ -294,21 +297,3 @@ class ExecutionList(TopologicalSort):
|
||||
del blocked_by[node_id]
|
||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||
return list(blocked_by.keys())
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
|
||||
@@ -137,3 +137,19 @@ def add_graph_prefix(graph, outputs, prefix):
|
||||
|
||||
return new_graph, tuple(new_outputs)
|
||||
|
||||
class ExecutionBlocker:
|
||||
"""
|
||||
Return this from a node and any users will be blocked with the given error message.
|
||||
If the message is None, execution will be blocked silently instead.
|
||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||
possible, a lazy input will be more efficient and have a better user experience.
|
||||
This functionality is useful in two cases:
|
||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||
lazy evaluation to let it conditionally disable itself.)
|
||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||
"""
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import TypedDict, Dict, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict, Dict, Optional, Tuple
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
@@ -10,6 +12,7 @@ if TYPE_CHECKING:
|
||||
from protocol import BinaryEventTypes
|
||||
from comfy_api import feature_flags
|
||||
|
||||
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
@@ -52,7 +55,7 @@ class ProgressHandler(ABC):
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
"""Called when a node's progress is updated"""
|
||||
pass
|
||||
@@ -103,7 +106,7 @@ class CLIProgressHandler(ProgressHandler):
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
# Handle case where start_handler wasn't called
|
||||
if node_id not in self.progress_bars:
|
||||
@@ -196,7 +199,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
max_value: float,
|
||||
state: NodeProgressState,
|
||||
prompt_id: str,
|
||||
image: Optional[Image.Image] = None,
|
||||
image: PreviewImageTuple | None = None,
|
||||
):
|
||||
# Send progress state of all nodes
|
||||
if self.registry:
|
||||
@@ -231,7 +234,6 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
if self.registry:
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
|
||||
class ProgressRegistry:
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
@@ -285,7 +287,7 @@ class ProgressRegistry:
|
||||
handler.start_handler(node_id, entry, self.prompt_id)
|
||||
|
||||
def update_progress(
|
||||
self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
|
||||
self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
|
||||
) -> None:
|
||||
"""Update progress for a node"""
|
||||
entry = self.ensure_entry(node_id)
|
||||
@@ -317,7 +319,7 @@ class ProgressRegistry:
|
||||
handler.reset()
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry = None
|
||||
global_progress_registry: ProgressRegistry | None = None
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
global global_progress_registry
|
||||
|
||||
@@ -8,9 +8,7 @@ import json
|
||||
from typing import Optional, Literal
|
||||
from fractions import Fraction
|
||||
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
||||
from comfy_api.input import ImageInput, AudioInput, VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest import Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM:
|
||||
@@ -91,8 +89,8 @@ class SaveVideo(ComfyNodeABC):
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
||||
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
||||
"format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
@@ -108,7 +106,7 @@ class SaveVideo(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
|
||||
def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
@@ -127,7 +125,7 @@ class SaveVideo(ComfyNodeABC):
|
||||
metadata["prompt"] = prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=format,
|
||||
@@ -163,9 +161,9 @@ class CreateVideo(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Create a video from images."
|
||||
|
||||
def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
|
||||
return (VideoFromComponents(
|
||||
VideoComponents(
|
||||
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
|
||||
return (InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(
|
||||
images=images,
|
||||
audio=audio,
|
||||
frame_rate=Fraction(fps),
|
||||
@@ -187,7 +185,7 @@ class GetVideoComponents(ComfyNodeABC):
|
||||
CATEGORY = "image/video"
|
||||
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
||||
|
||||
def get_components(self, video: VideoInput):
|
||||
def get_components(self, video: Input.Video):
|
||||
components = video.get_components()
|
||||
|
||||
return (components.images, components.audio, float(components.frame_rate))
|
||||
@@ -208,7 +206,7 @@ class LoadVideo(ComfyNodeABC):
|
||||
FUNCTION = "load_video"
|
||||
def load_video(self, file):
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return (VideoFromFile(video_path),)
|
||||
return (InputImpl.VideoFromFile(video_path),)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, file):
|
||||
@@ -239,3 +237,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GetVideoComponents": "Get Video Components",
|
||||
"LoadVideo": "Load Video",
|
||||
}
|
||||
|
||||
|
||||
@@ -149,6 +149,7 @@ class WanFirstLastFrameToVideo:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
clip_vision_output = None
|
||||
if clip_vision_start_image is not None:
|
||||
clip_vision_output = clip_vision_start_image
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.47"
|
||||
__version__ = "0.3.49"
|
||||
|
||||
147
execution.py
147
execution.py
@@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
from typing import List, Literal, NamedTuple, Optional, Union
|
||||
import asyncio
|
||||
|
||||
import torch
|
||||
@@ -32,6 +32,8 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@@ -56,7 +58,15 @@ class IsChangedCache:
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if not hasattr(class_def, "IS_CHANGED"):
|
||||
has_is_changed = False
|
||||
is_changed_name = None
|
||||
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
|
||||
has_is_changed = True
|
||||
is_changed_name = "fingerprint_inputs"
|
||||
elif hasattr(class_def, "IS_CHANGED"):
|
||||
has_is_changed = True
|
||||
is_changed_name = "IS_CHANGED"
|
||||
if not has_is_changed:
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
|
||||
@@ -65,9 +75,9 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
@@ -126,9 +136,14 @@ class CacheSet:
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
if is_v3:
|
||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||
else:
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
missing_keys = {}
|
||||
hidden_inputs_v3 = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
@@ -153,22 +168,37 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
elif input_category is not None:
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys
|
||||
if is_v3:
|
||||
if schema.hidden:
|
||||
if io.Hidden.prompt in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||
if io.Hidden.dynprompt in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||
if io.Hidden.extra_pnginfo in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||
if io.Hidden.unique_id in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org in schema.hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = [unique_id]
|
||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys, hidden_inputs_v3
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
@@ -184,7 +214,7 @@ async def resolve_map_node_over_list_results(results):
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -214,7 +244,22 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None and index is not None:
|
||||
pre_execute_cb(index)
|
||||
f = getattr(obj, func)
|
||||
# V3
|
||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||
# otherwise, use class instance to populate/reuse some fields
|
||||
else:
|
||||
type_obj = type(obj)
|
||||
type_obj.VALIDATE_CLASS()
|
||||
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||
f = make_locked_method_func(type_obj, func, class_clone)
|
||||
# V1
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
@@ -266,8 +311,8 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
@@ -298,6 +343,26 @@ def get_output_from_returns(return_values, obj):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
elif isinstance(r, _NodeOutputInternal):
|
||||
# V3
|
||||
if r.ui is not None:
|
||||
if isinstance(r.ui, dict):
|
||||
uis.append(r.ui)
|
||||
else:
|
||||
uis.append(r.ui.as_dict())
|
||||
if r.expand is not None:
|
||||
has_subgraph = True
|
||||
new_graph = r.expand
|
||||
result = r.result
|
||||
if r.block_execution is not None:
|
||||
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||
subgraph_results.append((new_graph, result))
|
||||
elif r.result is not None:
|
||||
result = r.result
|
||||
if r.block_execution is not None:
|
||||
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
else:
|
||||
if isinstance(r, ExecutionBlocker):
|
||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||
@@ -381,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@@ -391,8 +456,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
if lazy_status_present:
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
@@ -424,7 +493,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
def pre_execute_cb(call_index):
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
@@ -672,8 +741,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
|
||||
validate_function_inputs = []
|
||||
validate_has_kwargs = False
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
validate_function_name = "validate_inputs"
|
||||
validate_function = first_real_override(obj_class, validate_function_name)
|
||||
else:
|
||||
validate_function_name = "VALIDATE_INPUTS"
|
||||
validate_function = getattr(obj_class, validate_function_name, None)
|
||||
if validate_function is not None:
|
||||
argspec = inspect.getfullargspec(validate_function)
|
||||
validate_function_inputs = argspec.args
|
||||
validate_has_kwargs = argspec.varkw is not None
|
||||
received_types = {}
|
||||
@@ -848,7 +923,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
||||
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
@@ -856,8 +931,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
if 'input_types' in validate_function_inputs:
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
@@ -891,7 +965,7 @@ def full_type_name(klass):
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
async def validate_prompt(prompt_id, prompt):
|
||||
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
if 'class_type' not in prompt[x]:
|
||||
@@ -915,7 +989,8 @@ async def validate_prompt(prompt_id, prompt):
|
||||
return (False, error, [], {})
|
||||
|
||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||
outputs.add(x)
|
||||
if partial_execution_list is None or x in partial_execution_list:
|
||||
outputs.add(x)
|
||||
|
||||
if len(outputs) == 0:
|
||||
error = {
|
||||
|
||||
4
main.py
4
main.py
@@ -313,10 +313,10 @@ def start_comfyui(asyncio_loop=None):
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
nodes.init_extra_nodes(
|
||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
)
|
||||
))
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
||||
75
nodes.py
75
nodes.py
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import hashlib
|
||||
import inspect
|
||||
import traceback
|
||||
import math
|
||||
import time
|
||||
@@ -26,6 +28,9 @@ import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.controlnet
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||
from comfy_api.version_list import supported_versions
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
|
||||
import comfy.clip_vision
|
||||
|
||||
@@ -920,7 +925,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@@ -2101,7 +2106,7 @@ def get_module_name(module_path: str) -> str:
|
||||
return base_path
|
||||
|
||||
|
||||
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||
async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
|
||||
module_name = get_module_name(module_path)
|
||||
if os.path.isfile(module_path):
|
||||
sp = os.path.splitext(module_path)
|
||||
@@ -2149,6 +2154,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
||||
if os.path.isdir(web_dir):
|
||||
EXTENSION_WEB_DIRS[module_name] = web_dir
|
||||
|
||||
# V1 node definition
|
||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||
for name, node_cls in module.NODE_CLASS_MAPPINGS.items():
|
||||
if name not in ignore:
|
||||
@@ -2157,15 +2163,45 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
# V3 Extension Definition
|
||||
elif hasattr(module, "comfy_entrypoint"):
|
||||
entrypoint = getattr(module, "comfy_entrypoint")
|
||||
if not callable(entrypoint):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} is not callable, skipping.")
|
||||
return False
|
||||
try:
|
||||
if inspect.iscoroutinefunction(entrypoint):
|
||||
extension = await entrypoint()
|
||||
else:
|
||||
extension = entrypoint()
|
||||
if not isinstance(extension, ComfyExtension):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
|
||||
return False
|
||||
node_list = await extension.get_node_list()
|
||||
if not isinstance(node_list, list):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
|
||||
return False
|
||||
for node_cls in node_list:
|
||||
node_cls: io.ComfyNode
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
if schema.node_id not in ignore:
|
||||
NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
|
||||
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format(module_parent, get_module_name(module_path))
|
||||
if schema.display_name is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
return False
|
||||
|
||||
def init_external_custom_nodes():
|
||||
async def init_external_custom_nodes():
|
||||
"""
|
||||
Initializes the external custom nodes.
|
||||
|
||||
@@ -2191,7 +2227,7 @@ def init_external_custom_nodes():
|
||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
@@ -2204,7 +2240,7 @@ def init_external_custom_nodes():
|
||||
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
||||
logging.info("")
|
||||
|
||||
def init_builtin_extra_nodes():
|
||||
async def init_builtin_extra_nodes():
|
||||
"""
|
||||
Initializes the built-in extra nodes in ComfyUI.
|
||||
|
||||
@@ -2283,18 +2319,18 @@ def init_builtin_extra_nodes():
|
||||
"nodes_string.py",
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py"
|
||||
"nodes_tcfg.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
for node_file in extras_files:
|
||||
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
||||
if not await load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
|
||||
import_failed.append(node_file)
|
||||
|
||||
return import_failed
|
||||
|
||||
|
||||
def init_builtin_api_nodes():
|
||||
async def init_builtin_api_nodes():
|
||||
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
|
||||
api_nodes_files = [
|
||||
"nodes_ideogram.py",
|
||||
@@ -2315,26 +2351,35 @@ def init_builtin_api_nodes():
|
||||
"nodes_gemini.py",
|
||||
]
|
||||
|
||||
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
return api_nodes_files
|
||||
|
||||
import_failed = []
|
||||
for node_file in api_nodes_files:
|
||||
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
|
||||
import_failed.append(node_file)
|
||||
|
||||
return import_failed
|
||||
|
||||
async def init_public_apis():
|
||||
register_versions([
|
||||
ComfyAPIWithVersion(
|
||||
version=getattr(v, "VERSION"),
|
||||
api_class=v
|
||||
) for v in supported_versions
|
||||
])
|
||||
|
||||
def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
|
||||
import_failed = init_builtin_extra_nodes()
|
||||
async def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
|
||||
await init_public_apis()
|
||||
|
||||
import_failed = await init_builtin_extra_nodes()
|
||||
|
||||
import_failed_api = []
|
||||
if init_api_nodes:
|
||||
import_failed_api = init_builtin_api_nodes()
|
||||
import_failed_api = await init_builtin_api_nodes()
|
||||
|
||||
if init_custom_nodes:
|
||||
init_external_custom_nodes()
|
||||
await init_external_custom_nodes()
|
||||
else:
|
||||
logging.info("Skipping loading of custom nodes")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.47"
|
||||
version = "0.3.49"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
@@ -21,4 +21,4 @@ lint.select = [
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
]
|
||||
exclude = ["*.ipynb"]
|
||||
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.23.4
|
||||
comfyui-workflow-templates==0.1.41
|
||||
comfyui-workflow-templates==0.1.51
|
||||
comfyui-embedded-docs==0.2.4
|
||||
torch
|
||||
torchsde
|
||||
|
||||
10
server.py
10
server.py
@@ -30,6 +30,7 @@ from comfy_api import feature_flags
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@@ -591,6 +592,8 @@ class PromptServer():
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
return obj_class.GET_NODE_INFO_V1()
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
@@ -681,7 +684,12 @@ class PromptServer():
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||
|
||||
partial_execution_targets = None
|
||||
if "partial_execution_targets" in json_data:
|
||||
partial_execution_targets = json_data["partial_execution_targets"]
|
||||
|
||||
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
|
||||
@@ -7,7 +7,7 @@ import subprocess
|
||||
|
||||
from pytest import fixture
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from tests.inference.test_execution import ComfyClient
|
||||
from tests.inference.test_execution import ComfyClient, run_warmup
|
||||
|
||||
|
||||
@pytest.mark.execution
|
||||
@@ -24,6 +24,7 @@ class TestAsyncNodes:
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
@@ -82,6 +83,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that multiple async nodes execute in parallel."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
@@ -148,6 +152,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes with lazy evaluation."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_lazy")
|
||||
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@@ -305,6 +312,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that async nodes are properly cached."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_cache")
|
||||
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||
@@ -324,6 +334,9 @@ class TestAsyncNodes:
|
||||
|
||||
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test async nodes within dynamically generated prompts."""
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client, prefix="warmup_dynamic")
|
||||
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
@@ -15,10 +15,18 @@ import urllib.parse
|
||||
import urllib.error
|
||||
from comfy_execution.graph_utils import GraphBuilder, Node
|
||||
|
||||
def run_warmup(client, prefix="warmup"):
|
||||
"""Run a simple workflow to warm up the server."""
|
||||
warmup_g = GraphBuilder(prefix=prefix)
|
||||
warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
|
||||
warmup_g.node("PreviewImage", images=warmup_image.out(0))
|
||||
client.run(warmup_g)
|
||||
|
||||
class RunResult:
|
||||
def __init__(self, prompt_id: str):
|
||||
self.outputs: Dict[str,Dict] = {}
|
||||
self.runs: Dict[str,bool] = {}
|
||||
self.cached: Dict[str,bool] = {}
|
||||
self.prompt_id: str = prompt_id
|
||||
|
||||
def get_output(self, node: Node):
|
||||
@@ -27,6 +35,13 @@ class RunResult:
|
||||
def did_run(self, node: Node):
|
||||
return self.runs.get(node.id, False)
|
||||
|
||||
def was_cached(self, node: Node):
|
||||
return self.cached.get(node.id, False)
|
||||
|
||||
def was_executed(self, node: Node):
|
||||
"""Returns True if node was either run or cached"""
|
||||
return self.did_run(node) or self.was_cached(node)
|
||||
|
||||
def get_images(self, node: Node):
|
||||
output = self.get_output(node)
|
||||
if output is None:
|
||||
@@ -51,8 +66,10 @@ class ComfyClient:
|
||||
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
||||
self.ws = ws
|
||||
|
||||
def queue_prompt(self, prompt):
|
||||
def queue_prompt(self, prompt, partial_execution_targets=None):
|
||||
p = {"prompt": prompt, "client_id": self.client_id}
|
||||
if partial_execution_targets is not None:
|
||||
p["partial_execution_targets"] = partial_execution_targets
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
@@ -70,13 +87,13 @@ class ComfyClient:
|
||||
def set_test_name(self, name):
|
||||
self.test_name = name
|
||||
|
||||
def run(self, graph):
|
||||
def run(self, graph, partial_execution_targets=None):
|
||||
prompt = graph.finalize()
|
||||
for node in graph.nodes.values():
|
||||
if node.class_type == 'SaveImage':
|
||||
node.inputs['filename_prefix'] = self.test_name
|
||||
|
||||
prompt_id = self.queue_prompt(prompt)['prompt_id']
|
||||
prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
|
||||
result = RunResult(prompt_id)
|
||||
while True:
|
||||
out = self.ws.recv()
|
||||
@@ -92,7 +109,10 @@ class ComfyClient:
|
||||
elif message['type'] == 'execution_error':
|
||||
raise Exception(message['data'])
|
||||
elif message['type'] == 'execution_cached':
|
||||
pass # Probably want to store this off for testing
|
||||
if message['data']['prompt_id'] == prompt_id:
|
||||
cached_nodes = message['data'].get('nodes', [])
|
||||
for node_id in cached_nodes:
|
||||
result.cached[node_id] = True
|
||||
|
||||
history = self.get_history(prompt_id)[prompt_id]
|
||||
for node_id in history['outputs']:
|
||||
@@ -130,6 +150,7 @@ class TestExecution:
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
@@ -498,12 +519,15 @@ class TestExecution:
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
g = builder
|
||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create sleep nodes for each duration
|
||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
|
||||
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||
|
||||
# Add outputs to verify the execution
|
||||
@@ -515,10 +539,9 @@ class TestExecution:
|
||||
result = client.run(g)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# The test should take around 0.4 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
||||
# We'll allow for up to 0.8s total to account for overhead
|
||||
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
||||
# The test should take around 3.0 seconds (the longest sleep duration)
|
||||
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
|
||||
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
|
||||
|
||||
# Verify that all nodes executed
|
||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||
@@ -526,6 +549,9 @@ class TestExecution:
|
||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||
|
||||
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
g = builder
|
||||
# Create input images with different values
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
@@ -537,9 +563,9 @@ class TestExecution:
|
||||
image1=image1.out(0),
|
||||
image2=image2.out(0),
|
||||
image3=image3.out(0),
|
||||
sleep1=0.4,
|
||||
sleep2=0.5,
|
||||
sleep3=0.6)
|
||||
sleep1=4.8,
|
||||
sleep2=4.9,
|
||||
sleep3=5.0)
|
||||
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||
|
||||
start_time = time.time()
|
||||
@@ -548,7 +574,7 @@ class TestExecution:
|
||||
|
||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||
# which should complete in less than the sum of all sleeps
|
||||
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
||||
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
|
||||
|
||||
# Verify the parallel sleep node executed
|
||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||
@@ -585,3 +611,151 @@ class TestExecution:
|
||||
assert len(images) == 2, "Should have 2 images"
|
||||
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||
|
||||
# Output nodes included in the partial execution list are executed
|
||||
def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create two separate output nodes
|
||||
output1 = g.node("SaveImage", images=input1.out(0))
|
||||
output2 = g.node("SaveImage", images=input2.out(0))
|
||||
|
||||
# Run with partial execution targeting only output1
|
||||
result = client.run(g, partial_execution_targets=[output1.id])
|
||||
|
||||
assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
|
||||
assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
|
||||
assert not result.did_run(input2), "Input2 should not have run"
|
||||
assert not result.did_run(output2), "Output2 should not have run"
|
||||
|
||||
# Verify only output1 produced results
|
||||
assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
|
||||
assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
|
||||
|
||||
# Output nodes NOT included in the partial execution list are NOT executed
|
||||
def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create three output nodes
|
||||
output1 = g.node("SaveImage", images=input1.out(0))
|
||||
output2 = g.node("SaveImage", images=input2.out(0))
|
||||
output3 = g.node("SaveImage", images=input3.out(0))
|
||||
|
||||
# Run with partial execution targeting only output1 and output3
|
||||
result = client.run(g, partial_execution_targets=[output1.id, output3.id])
|
||||
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert result.was_executed(input3), "Input3 should have been executed"
|
||||
assert result.was_executed(output1), "Output1 should have been executed"
|
||||
assert result.was_executed(output3), "Output3 should have been executed"
|
||||
assert not result.did_run(input2), "Input2 should not have run"
|
||||
assert not result.did_run(output2), "Output2 should not have run"
|
||||
|
||||
# Output nodes NOT in list ARE executed if necessary for nodes that are in the list
|
||||
def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create a processing chain with an OUTPUT_NODE that has socket outputs
|
||||
output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
|
||||
|
||||
# Create another node that depends on the output_with_socket
|
||||
dependent_node = g.node("TestLazyMixImages",
|
||||
image1=output_with_socket.out(0),
|
||||
image2=input1.out(0),
|
||||
mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
|
||||
|
||||
# Create the final output
|
||||
final_output = g.node("SaveImage", images=dependent_node.out(0))
|
||||
|
||||
# Run with partial execution targeting only the final output
|
||||
result = client.run(g, partial_execution_targets=[final_output.id])
|
||||
|
||||
# All nodes should have been executed because they're dependencies
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
|
||||
assert result.was_executed(dependent_node), "Dependent node should have been executed"
|
||||
assert result.was_executed(final_output), "Final output should have been executed"
|
||||
|
||||
# Lazy execution works with partial execution
|
||||
def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create masks that will trigger different lazy execution paths
|
||||
mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1
|
||||
mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images
|
||||
|
||||
# Create two lazy mix nodes
|
||||
lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
|
||||
lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
|
||||
|
||||
output1 = g.node("SaveImage", images=lazy_mix1.out(0))
|
||||
output2 = g.node("SaveImage", images=lazy_mix2.out(0))
|
||||
|
||||
# Run with partial execution targeting only output1
|
||||
result = client.run(g, partial_execution_targets=[output1.id])
|
||||
|
||||
# For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
|
||||
assert result.was_executed(mask1), "Mask1 should have been executed"
|
||||
assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
|
||||
assert result.was_executed(output1), "Output1 should have been executed"
|
||||
|
||||
# Nothing from output2 path should run
|
||||
assert not result.did_run(input3), "Input3 should not have run"
|
||||
assert not result.did_run(mask2), "Mask2 should not have run"
|
||||
assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
|
||||
assert not result.did_run(output2), "Output2 should not have run"
|
||||
|
||||
# Multiple OUTPUT_NODEs with dependencies
|
||||
def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
|
||||
# Create a chain of OUTPUT_NODEs
|
||||
output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
|
||||
output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
|
||||
|
||||
# Create regular output nodes
|
||||
save1 = g.node("SaveImage", images=output_node1.out(0))
|
||||
save2 = g.node("SaveImage", images=output_node2.out(0))
|
||||
save3 = g.node("SaveImage", images=input2.out(0))
|
||||
|
||||
# Run targeting only save2
|
||||
result = client.run(g, partial_execution_targets=[save2.id])
|
||||
|
||||
# Should run: input1, output_node1, output_node2, save2
|
||||
assert result.was_executed(input1), "Input1 should have been executed"
|
||||
assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
|
||||
assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
|
||||
assert result.was_executed(save2), "Save2 should have been executed"
|
||||
|
||||
# Should NOT run: input2, save1, save3
|
||||
assert not result.did_run(input2), "Input2 should not have run"
|
||||
assert not result.did_run(save1), "Save1 should not have run"
|
||||
assert not result.did_run(save3), "Save3 should not have run"
|
||||
|
||||
# Empty partial execution list (should execute nothing)
|
||||
def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
_output1 = g.node("SaveImage", images=input1.out(0))
|
||||
|
||||
# Run with empty partial execution list
|
||||
try:
|
||||
_result = client.run(g, partial_execution_targets=[])
|
||||
# Should get an error because no outputs are selected
|
||||
assert False, "Should have raised an error for empty partial execution list"
|
||||
except urllib.error.HTTPError:
|
||||
pass # Expected behavior
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPING
|
||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
@@ -15,6 +16,7 @@ NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
@@ -23,4 +25,4 @@ NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
78
tests/inference/testing_nodes/testing-pack/api_test_nodes.py
Normal file
78
tests/inference/testing_nodes/testing-pack/api_test_nodes.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import time
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync
|
||||
|
||||
api = ComfyAPI()
|
||||
api_sync = ComfyAPISync()
|
||||
|
||||
|
||||
class TestAsyncProgressUpdate(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def execute(self, value, sleep_seconds):
|
||||
start = time.time()
|
||||
expiration = start + sleep_seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
await api.execution.set_progress(
|
||||
value=(now - start) / sleep_seconds,
|
||||
max_value=1.0,
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
|
||||
class TestSyncProgressUpdate(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def execute(self, value, sleep_seconds):
|
||||
start = time.time()
|
||||
expiration = start + sleep_seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
api_sync.execution.set_progress(
|
||||
value=(now - start) / sleep_seconds,
|
||||
max_value=1.0,
|
||||
)
|
||||
time.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
|
||||
API_TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestAsyncProgressUpdate": TestAsyncProgressUpdate,
|
||||
"TestSyncProgressUpdate": TestSyncProgressUpdate,
|
||||
}
|
||||
|
||||
API_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestAsyncProgressUpdate": "Async Progress Update Test Node",
|
||||
"TestSyncProgressUpdate": "Sync Progress Update Test Node",
|
||||
}
|
||||
@@ -463,6 +463,25 @@ class TestParallelSleep(ComfyNodeABC):
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestOutputNodeWithSocketOutput:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def process(self, image, value):
|
||||
# Apply value scaling and return both as output and socket
|
||||
result = image * value
|
||||
return (result,)
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@@ -478,6 +497,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -495,4 +515,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user