Compare commits

..

6 Commits

Author SHA1 Message Date
Jacob Segal
e9db9554bd ACTUALLY skip timing checks in CI 2025-09-03 12:30:44 -07:00
Jacob Segal
98be8e1969 Whoops, reverted a change in the last commit 2025-09-03 12:21:13 -07:00
Jacob Segal
766ff74207 Just disable timing-related assertions in CI
That way there's no risk of periodic non-deterministic test failures.
2025-09-03 12:18:48 -07:00
Jacob Segal
b1b5f87534 Add more leeway on async tests for Windows CI 2025-09-02 23:43:15 -07:00
Jacob Segal
cf45fd1742 Add missing test modules 2025-09-02 23:21:43 -07:00
Jacob Segal
e7314f49e6 Fix showing progress from other sessions
Because `client_id` was missing from ths `progress_state` message, it
was being sent to all connected sessions. This technically meant that if
someone had a graph with the same nodes, they would see the progress
updates for others.

Also added a test to prevent reoccurance and moved the tests around to
make CI easier to hook up.
2025-09-02 23:17:26 -07:00
52 changed files with 599 additions and 6042 deletions

View File

@@ -145,7 +145,7 @@ class PerformanceFeature(enum.Enum):
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

View File

@@ -61,12 +61,8 @@ class CLIPEncoder(torch.nn.Module):
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
all_intermediate = None
if intermediate_output is not None:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
@@ -74,12 +70,6 @@ class CLIPEncoder(torch.nn.Module):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):

View File

@@ -50,13 +50,7 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -74,18 +68,12 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
if self.return_all_hidden_states:
all_hs = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = all_hs[:, -2]
outputs["all_hidden_states"] = all_hs
else:
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs
@@ -136,12 +124,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
# Dinov2
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None

View File

@@ -146,13 +146,11 @@ class IndexListContextHandler(ContextHandlerABC):
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)

View File

@@ -253,10 +253,7 @@ class ControlNet(ControlBase):
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
if c.ndim < self.cond_hint.ndim:
c = c.unsqueeze(2)
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
@@ -588,18 +585,11 @@ def load_controlnet_flux_instantx(sd, model_options={}):
def load_controlnet_qwen_instantx(sd, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
extra_condition_channels = 0
concat_mask = False
if control_latent_channels == 68: #inpaint controlnet
extra_condition_channels = control_latent_channels - 64
concat_mask = True
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.Wan21()
extra_conds = []
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):

View File

@@ -31,20 +31,6 @@ class LayerScale(torch.nn.Module):
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class Dinov2MLP(torch.nn.Module):
def __init__(self, hidden_size: int, dtype, device, operations):
super().__init__()
mlp_ratio = 4
hidden_features = int(hidden_size * mlp_ratio)
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.fc1(hidden_state)
hidden_state = torch.nn.functional.gelu(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
@@ -64,15 +50,12 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
if use_swiglu_ffn:
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
else:
self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
@@ -83,10 +66,9 @@ class Dino2Block(torch.nn.Module):
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -96,8 +78,8 @@ class Dino2Encoder(torch.nn.Module):
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, layer in enumerate(self.layer):
x = layer(x, optimized_attention)
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
@@ -146,10 +128,9 @@ class Dinov2Model(torch.nn.Module):
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):

View File

@@ -1,22 +0,0 @@
{
"hidden_size": 1024,
"use_mask_token": true,
"patch_size": 14,
"image_size": 518,
"num_channels": 3,
"num_attention_heads": 16,
"initializer_range": 0.02,
"attention_probs_dropout_prob": 0.0,
"hidden_dropout_prob": 0.0,
"hidden_act": "gelu",
"mlp_ratio": 4,
"model_type": "dinov2",
"num_hidden_layers": 24,
"layer_norm_eps": 1e-6,
"qkv_bias": true,
"use_swiglu_ffn": false,
"layerscale_value": 1.0,
"drop_path_rate": 0.0,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@@ -533,89 +533,11 @@ class Wan22(Wan21):
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
scale_factor = 0.75289
latent_rgb_factors = [
[-0.0154, -0.0397, -0.0521],
[ 0.0005, 0.0093, 0.0006],
[-0.0805, -0.0773, -0.0586],
[-0.0494, -0.0487, -0.0498],
[-0.0212, -0.0076, -0.0261],
[-0.0179, -0.0417, -0.0505],
[ 0.0158, 0.0310, 0.0239],
[ 0.0409, 0.0516, 0.0201],
[ 0.0350, 0.0553, 0.0036],
[-0.0447, -0.0327, -0.0479],
[-0.0038, -0.0221, -0.0365],
[-0.0423, -0.0718, -0.0654],
[ 0.0039, 0.0368, 0.0104],
[ 0.0655, 0.0217, 0.0122],
[ 0.0490, 0.1638, 0.2053],
[ 0.0932, 0.0829, 0.0650],
[-0.0186, -0.0209, -0.0135],
[-0.0080, -0.0076, -0.0148],
[-0.0284, -0.0201, 0.0011],
[-0.0642, -0.0294, -0.0777],
[-0.0035, 0.0076, -0.0140],
[ 0.0519, 0.0731, 0.0887],
[-0.0102, 0.0095, 0.0704],
[ 0.0068, 0.0218, -0.0023],
[-0.0726, -0.0486, -0.0519],
[ 0.0260, 0.0295, 0.0263],
[ 0.0250, 0.0333, 0.0341],
[ 0.0168, -0.0120, -0.0174],
[ 0.0226, 0.1037, 0.0114],
[ 0.2577, 0.1906, 0.1604],
[-0.0646, -0.0137, -0.0018],
[-0.0112, 0.0309, 0.0358],
[-0.0347, 0.0146, -0.0481],
[ 0.0234, 0.0179, 0.0201],
[ 0.0157, 0.0313, 0.0225],
[ 0.0423, 0.0675, 0.0524],
[-0.0031, 0.0027, -0.0255],
[ 0.0447, 0.0555, 0.0330],
[-0.0152, 0.0103, 0.0299],
[-0.0755, -0.0489, -0.0635],
[ 0.0853, 0.0788, 0.1017],
[-0.0272, -0.0294, -0.0471],
[ 0.0440, 0.0400, -0.0137],
[ 0.0335, 0.0317, -0.0036],
[-0.0344, -0.0621, -0.0984],
[-0.0127, -0.0630, -0.0620],
[-0.0648, 0.0360, 0.0924],
[-0.0781, -0.0801, -0.0409],
[ 0.0363, 0.0613, 0.0499],
[ 0.0238, 0.0034, 0.0041],
[-0.0135, 0.0258, 0.0310],
[ 0.0614, 0.1086, 0.0589],
[ 0.0428, 0.0350, 0.0205],
[ 0.0153, 0.0173, -0.0018],
[-0.0288, -0.0455, -0.0091],
[ 0.0344, 0.0109, -0.0157],
[-0.0205, -0.0247, -0.0187],
[ 0.0487, 0.0126, 0.0064],
[-0.0220, -0.0013, 0.0074],
[-0.0203, -0.0094, -0.0048],
[-0.0719, 0.0429, -0.0442],
[ 0.1042, 0.0497, 0.0356],
[-0.0659, -0.0578, -0.0280],
[-0.0060, -0.0322, -0.0234]]
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2_1(LatentFormat):
scale_factor = 1.0039506158752403
latent_channels = 64
latent_dimensions = 1
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module):
# Attention layers
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
else:
rotary_pos_emb = None

View File

@@ -106,7 +106,6 @@ class Flux(nn.Module):
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -118,17 +117,9 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -248,7 +239,7 @@ class Flux(nn.Module):
index += 1
h_offset = 0
w_offset = 0
elif ref_latents_method == "uxo":
elif ref_latents_method == "uso":
index = 0
h_offset = h_len * patch_size + h
w_offset = w_len * patch_size + w

View File

@@ -4,458 +4,81 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple, List, Callable, Optional
import numpy as np
import math
from einops import repeat, rearrange
from tqdm import tqdm
from typing import Optional
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
# manually create the pointer vector
assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype = torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr_vec = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_vec[1:])
#return fps_sampling(src, ptr_vec, ratio)
sampled_indicies = []
for b in range(batch_size):
# start and the end of each batch
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
# points from the point cloud
points = src[start:end]
num_points = points.size(0)
num_samples = max(1, math.ceil(num_points * sampling_ratio))
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
distances = torch.full((num_points,), float("inf"), device = src.device)
# select a random start point
if start_random:
farthest = torch.randint(0, num_points, (1,), device = src.device)
else:
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
for i in range(num_samples):
selected[i] = farthest
centroid = points[farthest].squeeze(0)
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
distances = torch.minimum(distances, dist)
farthest = torch.argmax(distances)
sampled_indicies.append(torch.arange(start, end)[selected])
return torch.cat(sampled_indicies, dim = 0)
class PointCrossAttention(nn.Module):
def __init__(self,
num_latents: int,
downsample_ratio: float,
pc_size: int,
pc_sharpedge_size: int,
point_feats: int,
width: int,
heads: int,
layers: int,
fourier_embedder,
normal_pe: bool = False,
qkv_bias: bool = False,
use_ln_post: bool = True,
qk_norm: bool = True):
super().__init__()
self.fourier_embedder = fourier_embedder
self.pc_size = pc_size
self.normal_pe = normal_pe
self.downsample_ratio = downsample_ratio
self.pc_sharpedge_size = pc_sharpedge_size
self.num_latents = num_latents
self.point_feats = point_feats
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
self.cross_attn = ResidualCrossAttentionBlock(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm
)
self.self_attn = None
if layers > 0:
self.self_attn = Transformer(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm,
layers = layers
)
if use_ln_post:
self.ln_post = nn.LayerNorm(width)
else:
self.ln_post = None
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
"""
Subsample points randomly from the point cloud (input_pc)
Further sample the subsampled points to get query_pc
take the fourier embeddings for both input and query pc
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
More computationally efficient.
Features are additional information for each point in the cloud
"""
B, _, D = point_cloud.shape
num_latents = int(self.num_latents)
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
num_sharpedge_query = num_latents - num_random_query
# Split random and sharpedge surface points
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
# assert statements
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
input_random_pc_size = int(num_random_query * self.downsample_ratio)
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
if input_sharpedge_pc_size == 0:
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
else:
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
# concat the random and sharpedges
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
query = self.fourier_embedder(query_pc)
data = self.fourier_embedder(input_pc)
if self.point_feats > 0:
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
input_random_surface_features, query_random_features = \
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
if input_sharpedge_pc_size == 0:
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
dtype = input_random_surface_features.dtype, device = point_cloud.device)
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
dtype = query_random_features.dtype, device = point_cloud.device)
else:
input_sharpedge_surface_features, query_sharpedge_features = \
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
if self.normal_pe:
# apply the fourier embeddings on the first 3 dims (xyz)
input_features_pe = self.fourier_embedder(input_features[..., :3])
query_features_pe = self.fourier_embedder(query_features[..., :3])
# replace the first 3 dims with the new PE ones
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
# concat at the channels dim
query = torch.cat([query, query_features], dim = -1)
data = torch.cat([data, input_features], dim = -1)
# don't return pc_info to avoid unnecessary memory usuage
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
# apply projections
query = self.input_proj(query)
data = self.input_proj(data)
# apply cross attention between query and data
latents = self.cross_attn(query, data)
if self.self_attn is not None:
latents = self.self_attn(latents)
if self.ln_post is not None:
latents = self.ln_post(latents)
return latents
def subsample(self, pc, num_query, input_pc_size: int):
"""
num_query: number of points to keep after FPS
input_pc_size: number of points to select before FPS
"""
B, _, D = pc.shape
query_ratio = num_query / input_pc_size
# random subsampling of points inside the point cloud
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
input_pc = pc[:, idx_pc, :]
# flatten to allow applying fps across the whole batch
flattent_input_pc = input_pc.view(B * input_pc_size, D)
# construct a batch_down tensor to tell fps
# which points belong to which batch
N_down = int(flattent_input_pc.shape[0] / B)
batch_down = torch.arange(B).to(pc.device)
batch_down = torch.repeat_interleave(batch_down, N_down)
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
return query_pc, input_pc, idx_pc, idx_query
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
B = batch_size
input_surface_features = features[:, idx_pc, :]
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
query_features = flattent_input_features[idx_query].view(B, -1,
flattent_input_features.shape[-1])
return input_surface_features, query_features
def normalize_mesh(mesh, scale = 0.9999):
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
bbox = mesh.bounds
center = (bbox[1] + bbox[0]) / 2
max_extent = (bbox[1] - bbox[0]).max()
mesh.apply_translation(-center)
mesh.apply_scale((2 * scale) / max_extent)
return mesh
def sample_pointcloud(mesh, num = 200000):
""" Uniformly sample points from the surface of the mesh """
points, face_idx = mesh.sample(num, return_index = True)
normals = mesh.face_normals[face_idx]
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
def detect_sharp_edges(mesh, threshold=0.985):
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
V, F = mesh.vertices, mesh.faces
VN, FN = mesh.vertex_normals, mesh.face_normals
sharp_mask = np.ones(V.shape[0])
for i in range(3):
indices = F[:, i]
alignment = np.einsum('ij,ij->i', VN[indices], FN)
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
sharp_mask[indices] = np.min(dot_stack, axis=-1)
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
return edge_a[sharp_edges], edge_b[sharp_edges]
def sharp_sample_pointcloud(mesh, num = 16384):
""" Sample points preferentially from sharp edges in the mesh. """
edge_a, edge_b = detect_sharp_edges(mesh)
V, VN = mesh.vertices, mesh.vertex_normals
va, vb = V[edge_a], V[edge_b]
na, nb = VN[edge_a], VN[edge_b]
edge_lengths = np.linalg.norm(vb - va, axis=-1)
weights = edge_lengths / edge_lengths.sum()
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
t = np.random.rand(num, 1)
samples = t * va[indices] + (1 - t) * vb[indices]
normals = t * na[indices] + (1 - t) * nb[indices]
return samples.astype(np.float32), normals.astype(np.float32)
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
import trimesh
try:
mesh_full = trimesh.util.concatenate(mesh.dump())
except Exception:
mesh_full = trimesh.util.concatenate(mesh)
mesh_full = normalize_mesh(mesh_full)
faces = mesh_full.faces
vertices = mesh_full.vertices
origin_face_count = faces.shape[0]
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
area_surface = mesh_surface.area
area_fill = mesh_fill.area
total_area = area_surface + area_fill
sample_num = 499712 // 2
fill_ratio = area_fill / total_area if total_area > 0 else 0
num_fill = int(sample_num * fill_ratio)
num_surface = sample_num - num_fill
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
def assemble_tensor(points, normals, label=None):
data = torch.cat([points, normals], dim=1).half().to(device)
if label is not None:
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
data = torch.cat([data, label_tensor], dim=1)
return data
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
label = 0 if sharpedge_flag else None)
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
label = 1 if sharpedge_flag else None)
rng = np.random.default_rng()
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
return full
class SharpEdgeSurfaceLoader:
""" Load mesh surface and sharp edge samples. """
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
self.num_uniform_points = num_uniform_points
self.num_sharp_points = num_sharp_points
self.total_points = num_uniform_points + num_sharp_points
def __call__(self, mesh_input, device = "cuda"):
mesh = self._load_mesh(mesh_input)
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
@staticmethod
def _load_mesh(mesh_input):
import trimesh
if isinstance(mesh_input, str):
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
else:
mesh = mesh_input
if isinstance(mesh, trimesh.Scene):
combined = None
for obj in mesh.geometry.values():
combined = obj if combined is None else combined + obj
return combined
return mesh
class DiagonalGaussianDistribution:
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
# divide quant channels (8) into mean and log variance
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)
def sample(self):
eps = torch.randn_like(self.std)
z = self.mean + eps * self.std
return z
################################################
# Volume Decoder
################################################
class VanillaVolumeDecoder():
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
return xyz, grid_size, length
class VanillaVolumeDecoder:
@torch.no_grad()
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz[start: start + num_chunks, :]
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
logits = geo_decoder(queries = chunk_queries, latents = latents)
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim = 1)
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
return grid_logits
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
@@ -552,11 +175,13 @@ class FourierEmbedder(nn.Module):
else:
return x
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
@@ -607,42 +232,39 @@ class MLP(nn.Module):
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
n_data = None,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
def forward(self, q, kv):
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
@@ -684,6 +306,7 @@ class MultiheadCrossAttention(nn.Module):
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
@@ -743,7 +366,7 @@ class QKVMultiheadAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@@ -760,7 +383,8 @@ class MultiheadAttention(nn.Module):
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.heads = heads
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
@@ -867,7 +491,7 @@ class CrossAttentionDecoder(nn.Module):
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
if not self.enable_ln_post:
if self.enable_ln_post == False:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
@@ -898,44 +522,28 @@ class CrossAttentionDecoder(nn.Module):
class ShapeVAE(nn.Module):
def __init__(
self,
*,
num_latents: int = 4096,
embed_dim: int = 64,
width: int = 1024,
heads: int = 16,
num_decoder_layers: int = 16,
num_encoder_layers: int = 8,
pc_size: int = 81920,
pc_sharpedge_size: int = 0,
point_feats: int = 4,
downsample_ratio: int = 20,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
drop_path_rate: float = 0.0,
include_pi: bool = False,
scale_factor: float = 1.0039506158752403,
label_type: str = "binary",
self,
*,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.encoder = PointCrossAttention(layers = num_encoder_layers,
num_latents = num_latents,
downsample_ratio = downsample_ratio,
heads = heads,
pc_size = pc_size,
width = width,
point_feats = point_feats,
fourier_embedder = self.fourier_embedder,
pc_sharpedge_size = pc_sharpedge_size)
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
@@ -975,14 +583,5 @@ class ShapeVAE(nn.Module):
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
def encode(self, surface):
pc, feats = surface[:, :, :3], surface[:, :, 3:]
latents = self.encoder(pc, feats)
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
latents = posterior.sample()
return latents
def encode(self, x):
return None

View File

@@ -1,659 +0,0 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type == "mps":
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
return F.gelu(gate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class FeedForward(nn.Module):
def __init__(self, dim: int, dim_out = None, mult: int = 4,
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class AddAuxLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, x, loss):
# do nothing in forward (no computation)
ctx.requires_aux_loss = loss.requires_grad
ctx.dtype = loss.dtype
return x
@staticmethod
def backward(ctx, grad_output):
# add the aux loss gradients
grad_loss = None
# put the aux grad the same as the main grad loss
# aux grad contributes equally
if ctx.requires_aux_loss:
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
return grad_output, grad_loss
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
super().__init__()
self.top_k = num_experts_per_tok
self.n_routed_experts = num_experts
self.alpha = aux_loss_alpha
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# flatten hidden states
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
# get logits and pass it to softmax
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
scores = logits.softmax(dim = -1)
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
if self.training and self.alpha > 0.0:
scores_for_aux = scores
# used bincount instead of one hot encoding
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
ce = counts / topk_idx.numel() # normalized expert usage
# mean expert score
Pi = scores_for_aux.mean(0)
# expert balance loss
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
class MoEBlock(nn.Module):
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
super().__init__()
self.moe_top_k = moe_top_k
self.num_experts = num_experts
self.experts = nn.ModuleList([
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
for _ in range(num_experts)
])
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
def forward(self, hidden_states) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
for i, expert in enumerate(self.experts):
tmp = expert(hidden_states[flat_topk_idx == i])
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
y = y.view(*orig_shape)
y = AddAuxLoss.apply(y, aux_loss)
else:
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
# no need for .numpy().cpu() here
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
token_idxs = idxs // self.moe_top_k
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
# + avoid dtype conversion
expert_cache.index_add_(0, exp_token_idx, expert_out)
return expert_cache
class Timesteps(nn.Module):
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
scale: float = 1.0, max_period: int = 10000):
super().__init__()
self.num_channels = num_channels
half_dim = num_channels // 2
# precompute the “inv_freq” vector once
exponent = -math.log(max_period) * torch.arange(
half_dim, dtype=torch.float32
) / (half_dim - downscale_freq_shift)
inv_freq = torch.exp(exponent)
# pad
if num_channels % 2 == 1:
# well pad a zero at the end of the cos-half
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
# register to buffer so it moves with the device
self.register_buffer("inv_freq", inv_freq, persistent = False)
self.scale = scale
def forward(self, timesteps: torch.Tensor):
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
# fused CUDA kernels for sin and cos
sin_emb = x.sin()
cos_emb = x.cos()
emb = torch.cat([sin_emb, cos_emb], dim = 1)
# scale factor
if self.scale != 1.0:
emb = emb * self.scale
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
if emb.shape[1] > self.num_channels:
emb = emb[:, :self.num_channels]
return emb
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
nn.GELU(),
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
)
self.frequency_embedding_size = frequency_embedding_size
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
self.time_embed = Timesteps(hidden_size)
def forward(self, timesteps, condition):
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
if condition is not None:
cond_embed = self.cond_proj(condition)
timestep_embed = timestep_embed + cond_embed
time_conditioned = self.mlp(timestep_embed)
# for broadcasting with image tokens
return time_conditioned.unsqueeze(1)
class MLP(nn.Module):
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
super().__init__()
self.width = width
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
self.gelu = nn.GELU()
def forward(self, x):
return self.fc2(self.gelu(self.fc1(x)))
class CrossAttention(nn.Module):
def __init__(
self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
norm_layer=nn.LayerNorm,
use_fp16: bool = False,
operations = None,
dtype = None,
device = None,
**kwargs,
):
super().__init__()
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
self.head_dim = self.qdim // num_heads
self.scale = self.head_dim ** -0.5
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
if norm_layer == nn.LayerNorm:
norm_layer = operations.LayerNorm
else:
norm_layer = operations.RMSNorm
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
def forward(self, x, y):
b, s1, _ = x.shape
_, s2, _ = y.shape
y = y.to(next(self.to_k.parameters()).dtype)
q = self.to_q(x)
k = self.to_k(y)
v = self.to_v(y)
kv = torch.cat((k, v), dim=-1)
split_size = kv.shape[-1] // self.num_heads // 2
kv = kv.view(1, -1, self.num_heads, split_size * 2)
k, v = torch.split(kv, split_size, dim=-1)
q = q.view(b, s1, self.num_heads, self.head_dim)
k = k.view(b, s2, self.num_heads, self.head_dim)
v = v.reshape(b, s2, self.num_heads * self.head_dim)
q = self.q_norm(q)
k = self.k_norm(k)
x = optimized_attention(
q.reshape(b, s1, self.num_heads * self.head_dim),
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
)
out = self.out_proj(x)
return out
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads,
qkv_bias = True,
qk_norm = False,
norm_layer = nn.LayerNorm,
use_fp16: bool = False,
operations = None,
device = None,
dtype = None
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = self.dim // num_heads
self.scale = self.head_dim ** -0.5
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
if norm_layer == nn.LayerNorm:
norm_layer = operations.LayerNorm
else:
norm_layer = operations.RMSNorm
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
def forward(self, x):
B, N, _ = x.shape
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
qkv_combined = torch.cat((query, key, value), dim=-1)
split_size = qkv_combined.shape[-1] // self.num_heads // 3
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
query = query.reshape(B, N, self.num_heads, self.head_dim)
key = key.reshape(B, N, self.num_heads, self.head_dim)
value = value.reshape(B, N, self.num_heads * self.head_dim)
query = self.q_norm(query)
key = self.k_norm(key)
x = optimized_attention(
query.reshape(B, N, self.num_heads * self.head_dim),
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
)
x = self.out_proj(x)
return x
class HunYuanDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
c_emb_size,
num_heads,
text_states_dim=1024,
qk_norm=False,
norm_layer=nn.LayerNorm,
qk_norm_layer=True,
qkv_bias=True,
skip_connection=True,
timested_modulate=False,
use_moe: bool = False,
num_experts: int = 8,
moe_top_k: int = 2,
use_fp16: bool = False,
operations = None,
device = None, dtype = None
):
super().__init__()
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.timested_modulate = timested_modulate
if self.timested_modulate:
self.default_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
)
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
device = device, dtype = dtype, operations = operations)
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
if skip_connection:
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
else:
self.skip_linear = None
self.use_moe = use_moe
if self.use_moe:
self.moe = MoEBlock(
hidden_size,
num_experts = num_experts,
moe_top_k = moe_top_k,
dropout = 0.0,
ff_inner_dim = int(hidden_size * 4.0),
device = device, dtype = dtype,
operations = operations
)
else:
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
if self.skip_linear is not None:
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
hidden_states = self.skip_linear(combined)
hidden_states = self.skip_norm(hidden_states)
# self attention
if self.timested_modulate:
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
hidden_states = hidden_states + modulation_shift
self_attn_out = self.attn1(self.norm1(hidden_states))
hidden_states = hidden_states + self_attn_out
# cross attention
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
# MLP Layer
mlp_input = self.norm3(hidden_states)
if self.use_moe:
hidden_states = hidden_states + self.moe(mlp_input)
else:
hidden_states = hidden_states + self.mlp(mlp_input)
return hidden_states
class FinalLayer(nn.Module):
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
super().__init__()
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
def forward(self, x):
x = self.norm_final(x)
x = x[:, 1:]
x = self.linear(x)
return x
class HunYuanDiTPlain(nn.Module):
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
def __init__(
self,
in_channels: int = 64,
hidden_size: int = 2048,
context_dim: int = 1024,
depth: int = 21,
num_heads: int = 16,
qk_norm: bool = True,
qkv_bias: bool = False,
num_moe_layers: int = 6,
guidance_cond_proj_dim = 2048,
norm_type = 'layer',
num_experts: int = 8,
moe_top_k: int = 2,
use_fp16: bool = False,
dtype = None,
device = None,
operations = None,
**kwargs
):
self.dtype = dtype
super().__init__()
self.depth = depth
self.in_channels = in_channels
self.out_channels = in_channels
self.num_heads = num_heads
self.hidden_size = hidden_size
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
qk_norm = operations.RMSNorm
self.context_dim = context_dim
self.guidance_cond_proj_dim = guidance_cond_proj_dim
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
text_states_dim=context_dim,
qk_norm=qk_norm,
norm_layer = norm,
qk_norm_layer = qk_norm,
skip_connection=layer > depth // 2,
qkv_bias=qkv_bias,
use_moe=True if depth - layer <= num_moe_layers else False,
num_experts=num_experts,
moe_top_k=moe_top_k,
use_fp16 = use_fp16,
device = device, dtype = dtype, operations = operations)
for layer in range(depth)
])
self.depth = depth
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
uncond_emb, cond_emb = context.chunk(2, dim = 0)
context = torch.cat([cond_emb, uncond_emb], dim = 0)
main_condition = context
t = 1.0 - t
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
x_embedded = self.x_embedder(x)
combined = torch.cat([time_embedded, x_embedded], dim=1)
def block_wrap(args):
return block(
args["x"],
args["t"],
args["cond"],
skip_tensor=args.get("skip"),)
skip_stack = []
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for idx, block in enumerate(self.blocks):
if idx <= self.depth // 2:
skip_input = None
else:
skip_input = skip_stack.pop()
if ("block", idx) in blocks_replace:
combined = blocks_replace[("block", idx)](
{
"x": combined,
"t": time_embedded,
"cond": main_condition,
"skip": skip_input,
},
{"original_block": block_wrap},
)
else:
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
if idx < self.depth // 2:
skip_stack.append(combined)
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
cond_emb, uncond_emb = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb])

View File

@@ -40,8 +40,6 @@ class HunyuanVideoParams:
patch_size: list
qkv_bias: bool
guidance_embed: bool
byt5: bool
meanflow: bool
class SelfAttentionRef(nn.Module):
@@ -163,30 +161,6 @@ class TokenRefiner(nn.Module):
x = self.individual_token_refiner(x, c, mask)
return x
class ByT5Mapper(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
super().__init__()
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
self.use_res = use_res
self.act_fn = nn.GELU()
def forward(self, x):
if self.use_res:
res = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x2 = self.act_fn(x)
x2 = self.fc3(x2)
if self.use_res:
x2 = x2 + res
return x2
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -211,13 +185,9 @@ class HunyuanVideo(nn.Module):
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
if params.vec_in_dim is not None:
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.vector_in = None
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
@@ -245,23 +215,6 @@ class HunyuanVideo(nn.Module):
]
)
if params.byt5:
self.byt5_in = ByT5Mapper(
in_dim=1472,
out_dim=2048,
hidden_dim=2048,
out_dim1=self.hidden_size,
use_res=False,
dtype=dtype, device=device, operations=operations
)
else:
self.byt5_in = None
if params.meanflow:
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.time_r_in = None
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
@@ -273,8 +226,7 @@ class HunyuanVideo(nn.Module):
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor = None,
txt_byt5=None,
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
@@ -288,14 +240,6 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if self.time_r_in is not None:
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
@@ -306,17 +250,13 @@ class HunyuanVideo(nn.Module):
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
if self.vector_in is not None:
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
else:
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
if self.vector_in is not None:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
@@ -329,12 +269,6 @@ class HunyuanVideo(nn.Module):
txt = self.txt_in(txt, timesteps, txt_mask)
if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -394,16 +328,12 @@ class HunyuanVideo(nn.Module):
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-len(self.patch_size):]
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
if img.ndim == 8:
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
else:
img = img.permute(0, 3, 1, 4, 2, 5)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
def img_ids(self, x):
@@ -418,30 +348,16 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def img_ids_2d(self, x):
bs, c, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out

View File

@@ -1,136 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
import comfy.ops
ops = comfy.ops.disable_weight_init
class PixelShuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
self.ratio = (in_dim << 2) // out_dim
def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h >> 1, w >> 1
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
class PixelUnshuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
self.scale = (out_dim << 2) // in_dim
def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h << 1, w << 1
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
return y + r
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
def forward(self, x):
x = self.conv_in(x)
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
b, c, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, h, w).mean(2)
return self.conv_out(F.silu(self.norm_out(x))) + skip
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
ch = block_out_channels[0]
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.up.append(stage)
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
def forward(self, z):
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))

View File

@@ -145,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
dropout, temb_channels=512, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
def forward(self, x, temb=None):
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.swish(h)

View File

@@ -16,8 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -1284,21 +1282,6 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class Hunyuan3Dv2_1(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guidance = kwargs.get("guidance", 5.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@@ -1408,27 +1391,3 @@ class QwenImage(BaseModel):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out

View File

@@ -136,40 +136,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
else:
dit_config["vec_in_dim"] = None
if len(dit_config["patch_size"]) == 2:
dit_config["axes_dim"] = [64, 64]
else:
dit_config["axes_dim"] = [16, 56, 56]
if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["meanflow"] = True
else:
dit_config["meanflow"] = False
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["num_heads"] = 24
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
dit_config["byt5"] = True
else:
dit_config["byt5"] = False
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
@@ -420,20 +400,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
dit_config["context_dim"] = 1024
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
dit_config["qkv_bias"] = False
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"

View File

@@ -22,7 +22,6 @@ from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import importlib
import platform
import weakref
import gc
@@ -290,24 +289,6 @@ def is_amd():
return True
return False
def amd_min_version(device=None, min_rdna_version=0):
if not is_amd():
return False
if is_device_cpu(device):
return False
arch = torch.cuda.get_device_properties(device).gcnArchName
if arch.startswith('gfx') and len(arch) == 7:
try:
cmp_rdna_version = int(arch[4]) + 2
except:
cmp_rdna_version = 0
if cmp_rdna_version >= min_rdna_version:
return True
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.0
@@ -340,13 +321,12 @@ try:
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
@@ -925,9 +905,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
# also a problem on RDNA4 except fp32 is also slow there.
# This is due to large bf16 convolutions being extremely slow.
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d
return torch.float32

View File

@@ -433,9 +433,6 @@ class ModelPatcher:
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj

View File

@@ -17,7 +17,6 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import yaml
import math
import os
@@ -49,7 +48,6 @@ import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.model_patcher
import comfy.lora
@@ -330,19 +328,6 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -461,29 +446,17 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
batch, num_tokens, hidden_dim = shape
dtype_size = model_management.dtype_size(dtype)
total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
return total_mem
# better memory estimations
self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
ln_post = "geo_decoder.ln_post.weight" in sd
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
@@ -800,7 +773,6 @@ class CLIPType(Enum):
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -822,7 +794,6 @@ class TEModel(Enum):
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -840,9 +811,6 @@ def detect_te_model(sd):
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
if weight.shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
@@ -957,12 +925,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
if clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@@ -1006,9 +970,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@@ -20,7 +20,6 @@ import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
from . import supported_models_base
from . import latent_formats
@@ -1129,17 +1128,6 @@ class Hunyuan3Dv2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
class Hunyuan3Dv2_1(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2_1",
}
latent_format = latent_formats.Hunyuan3Dv2_1
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Hunyuan3Dv2_1(self, device = device)
return out
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1296,31 +1284,7 @@ class QwenImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
class HunyuanImage21(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vec_in_dim": None,
}
sampling_settings = {
"shift": 5.0,
}
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanImage21(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@@ -1,22 +0,0 @@
{
"d_ff": 3584,
"d_kv": 64,
"d_model": 1472,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 4,
"num_heads": 6,
"num_layers": 12,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 1510
}

View File

@@ -1,127 +0,0 @@
{
"<extra_id_0>": 259,
"<extra_id_100>": 359,
"<extra_id_101>": 360,
"<extra_id_102>": 361,
"<extra_id_103>": 362,
"<extra_id_104>": 363,
"<extra_id_105>": 364,
"<extra_id_106>": 365,
"<extra_id_107>": 366,
"<extra_id_108>": 367,
"<extra_id_109>": 368,
"<extra_id_10>": 269,
"<extra_id_110>": 369,
"<extra_id_111>": 370,
"<extra_id_112>": 371,
"<extra_id_113>": 372,
"<extra_id_114>": 373,
"<extra_id_115>": 374,
"<extra_id_116>": 375,
"<extra_id_117>": 376,
"<extra_id_118>": 377,
"<extra_id_119>": 378,
"<extra_id_11>": 270,
"<extra_id_120>": 379,
"<extra_id_121>": 380,
"<extra_id_122>": 381,
"<extra_id_123>": 382,
"<extra_id_124>": 383,
"<extra_id_12>": 271,
"<extra_id_13>": 272,
"<extra_id_14>": 273,
"<extra_id_15>": 274,
"<extra_id_16>": 275,
"<extra_id_17>": 276,
"<extra_id_18>": 277,
"<extra_id_19>": 278,
"<extra_id_1>": 260,
"<extra_id_20>": 279,
"<extra_id_21>": 280,
"<extra_id_22>": 281,
"<extra_id_23>": 282,
"<extra_id_24>": 283,
"<extra_id_25>": 284,
"<extra_id_26>": 285,
"<extra_id_27>": 286,
"<extra_id_28>": 287,
"<extra_id_29>": 288,
"<extra_id_2>": 261,
"<extra_id_30>": 289,
"<extra_id_31>": 290,
"<extra_id_32>": 291,
"<extra_id_33>": 292,
"<extra_id_34>": 293,
"<extra_id_35>": 294,
"<extra_id_36>": 295,
"<extra_id_37>": 296,
"<extra_id_38>": 297,
"<extra_id_39>": 298,
"<extra_id_3>": 262,
"<extra_id_40>": 299,
"<extra_id_41>": 300,
"<extra_id_42>": 301,
"<extra_id_43>": 302,
"<extra_id_44>": 303,
"<extra_id_45>": 304,
"<extra_id_46>": 305,
"<extra_id_47>": 306,
"<extra_id_48>": 307,
"<extra_id_49>": 308,
"<extra_id_4>": 263,
"<extra_id_50>": 309,
"<extra_id_51>": 310,
"<extra_id_52>": 311,
"<extra_id_53>": 312,
"<extra_id_54>": 313,
"<extra_id_55>": 314,
"<extra_id_56>": 315,
"<extra_id_57>": 316,
"<extra_id_58>": 317,
"<extra_id_59>": 318,
"<extra_id_5>": 264,
"<extra_id_60>": 319,
"<extra_id_61>": 320,
"<extra_id_62>": 321,
"<extra_id_63>": 322,
"<extra_id_64>": 323,
"<extra_id_65>": 324,
"<extra_id_66>": 325,
"<extra_id_67>": 326,
"<extra_id_68>": 327,
"<extra_id_69>": 328,
"<extra_id_6>": 265,
"<extra_id_70>": 329,
"<extra_id_71>": 330,
"<extra_id_72>": 331,
"<extra_id_73>": 332,
"<extra_id_74>": 333,
"<extra_id_75>": 334,
"<extra_id_76>": 335,
"<extra_id_77>": 336,
"<extra_id_78>": 337,
"<extra_id_79>": 338,
"<extra_id_7>": 266,
"<extra_id_80>": 339,
"<extra_id_81>": 340,
"<extra_id_82>": 341,
"<extra_id_83>": 342,
"<extra_id_84>": 343,
"<extra_id_85>": 344,
"<extra_id_86>": 345,
"<extra_id_87>": 346,
"<extra_id_88>": 347,
"<extra_id_89>": 348,
"<extra_id_8>": 267,
"<extra_id_90>": 349,
"<extra_id_91>": 350,
"<extra_id_92>": 351,
"<extra_id_93>": 352,
"<extra_id_94>": 353,
"<extra_id_95>": 354,
"<extra_id_96>": 355,
"<extra_id_97>": 356,
"<extra_id_98>": 357,
"<extra_id_99>": 358,
"<extra_id_9>": 268
}

View File

@@ -1,150 +0,0 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>",
"<extra_id_100>",
"<extra_id_101>",
"<extra_id_102>",
"<extra_id_103>",
"<extra_id_104>",
"<extra_id_105>",
"<extra_id_106>",
"<extra_id_107>",
"<extra_id_108>",
"<extra_id_109>",
"<extra_id_110>",
"<extra_id_111>",
"<extra_id_112>",
"<extra_id_113>",
"<extra_id_114>",
"<extra_id_115>",
"<extra_id_116>",
"<extra_id_117>",
"<extra_id_118>",
"<extra_id_119>",
"<extra_id_120>",
"<extra_id_121>",
"<extra_id_122>",
"<extra_id_123>",
"<extra_id_124>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,100 +0,0 @@
from comfy import sd1_clip
import comfy.text_encoders.llama
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from transformers import ByT5Tokenizer
import os
import re
class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
class HunyuanImageTokenizer(QwenImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
# self.llama_template_images = "{}"
self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
# ByT5 processing for HunyuanImage
text_prompt_texts = []
pattern_quote_single = r'\'(.*?)\''
pattern_quote_double = r'\"(.*?)\"'
pattern_quote_chinese_single = r'(.*?)'
pattern_quote_chinese_double = r'“(.*?)”'
matches_quote_single = re.findall(pattern_quote_single, text)
matches_quote_double = re.findall(pattern_quote_double, text)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
text_prompt_texts.extend(matches_quote_single)
text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)
if len(text_prompt_texts) > 0:
out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
return out
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ByT5SmallModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class HunyuanImageTEModel(QwenImageTEModel):
def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
if byt5:
self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
else:
self.byt5_small = None
def encode_token_weights(self, token_weight_pairs):
cond, p, extra = super().encode_token_weights(token_weight_pairs)
if self.byt5_small is not None and "byt5" in token_weight_pairs:
out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
extra["conditioning_byt5small"] = out[0]
return cond, p, extra
def set_clip_options(self, options):
super().set_clip_options(options)
if self.byt5_small is not None:
self.byt5_small.set_clip_options(options)
def reset_clip_options(self):
super().reset_clip_options()
if self.byt5_small is not None:
self.byt5_small.reset_clip_options()
def load_sd(self, sd):
if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd:
return self.byt5_small.load_sd(sd)
else:
return super().load_sd(sd)
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_

View File

@@ -128,12 +128,11 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed.to(org_dtype), k_embed.to(org_dtype)
return q_embed, k_embed
class Attention(nn.Module):

View File

@@ -1190,18 +1190,13 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
raise NotImplementedError
@classmethod
def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.
If the function returns a string, it will be used as the validation error message for the node.
"""
def validate_inputs(cls, **kwargs) -> bool:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError
@classmethod
def fingerprint_inputs(cls, **kwargs) -> Any:
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.
If this function returns the same value as last run, the node will not be executed."""
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED."""
raise NotImplementedError
@classmethod

View File

@@ -518,71 +518,6 @@ async def upload_audio_to_comfyapi(
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:

View File

@@ -951,11 +951,7 @@ class MagicPrompt2(str, Enum):
class StyleType1(str, Enum):
AUTO = 'AUTO'
GENERAL = 'GENERAL'
REALISTIC = 'REALISTIC'
DESIGN = 'DESIGN'
FICTION = 'FICTION'
class ImagenImageGenerationInstance(BaseModel):
@@ -2680,7 +2676,7 @@ class ReleaseNote(BaseModel):
class RenderingSpeed(str, Enum):
DEFAULT = 'DEFAULT'
BALANCED = 'BALANCED'
TURBO = 'TURBO'
QUALITY = 'QUALITY'
@@ -4922,14 +4918,6 @@ class IdeogramV3EditRequest(BaseModel):
None,
description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class IdeogramV3Request(BaseModel):
@@ -4963,14 +4951,6 @@ class IdeogramV3Request(BaseModel):
style_type: Optional[StyleType1] = Field(
None, description='The type of style to apply'
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class ImagenGenerateImageResponse(BaseModel):

View File

@@ -125,25 +125,3 @@ class StabilityResultsGetResponse(BaseModel):
class StabilityAsyncResponse(BaseModel):
id: Optional[str] = Field(None)
class StabilityTextToAudioRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
duration: int = Field(190, ge=1, le=190)
seed: int = Field(0, ge=0, le=4294967294)
steps: int = Field(8, ge=4, le=8)
output_format: str = Field("wav")
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
strength: float = Field(0.01, ge=0.01, le=1.0)
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
mask_start: int = Field(30, ge=0, le=190)
mask_end: int = Field(190, ge=0, le=190)
class StabilityAudioResponse(BaseModel):
audio: Optional[str] = Field(None)

File diff suppressed because it is too large Load Diff

View File

@@ -255,7 +255,6 @@ class IdeogramV1(comfy_io.ComfyNode):
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@@ -384,7 +383,6 @@ class IdeogramV2(comfy_io.ComfyNode):
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@@ -554,7 +552,6 @@ class IdeogramV3(comfy_io.ComfyNode):
category="api node/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@@ -615,21 +612,11 @@ class IdeogramV3(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"rendering_speed",
options=["DEFAULT", "TURBO", "QUALITY"],
default="DEFAULT",
options=["BALANCED", "TURBO", "QUALITY"],
default="BALANCED",
tooltip="Controls the trade-off between generation speed and quality",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Image to use as character reference.",
optional=True,
),
comfy_io.Mask.Input(
"character_mask",
tooltip="Optional mask for character reference image.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
@@ -652,46 +639,12 @@ class IdeogramV3(comfy_io.ComfyNode):
magic_prompt_option="AUTO",
seed=0,
num_images=1,
rendering_speed="DEFAULT",
character_image=None,
character_mask=None,
rendering_speed="BALANCED",
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
character_img_binary = None
character_mask_binary = None
if character_image is not None:
input_tensor = character_image.squeeze().cpu()
if character_mask is not None:
character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False)
character_mask = 1.0 - character_mask
if character_mask.shape[1:] != character_image.shape[1:-1]:
raise Exception("Character mask and image must be the same size")
mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_byte_arr = BytesIO()
mask_img.save(mask_byte_arr, format="PNG")
mask_byte_arr.seek(0)
character_mask_binary = mask_byte_arr
character_mask_binary.name = "mask.png"
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_np)
img_byte_arr = BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
character_img_binary = img_byte_arr
character_img_binary.name = "image.png"
elif character_mask is not None:
raise Exception("Character mask requires character image to be present")
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
@@ -740,15 +693,6 @@ class IdeogramV3(comfy_io.ComfyNode):
if num_images > 1:
edit_request.num_images = num_images
files = {
"image": img_binary,
"mask": mask_binary,
}
if character_img_binary:
files["character_reference_images"] = character_img_binary
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -758,7 +702,10 @@ class IdeogramV3(comfy_io.ComfyNode):
response_model=IdeogramGenerateResponse,
),
request=edit_request,
files=files,
files={
"image": img_binary,
"mask": mask_binary,
},
content_type="multipart/form-data",
auth_kwargs=auth,
)
@@ -792,14 +739,6 @@ class IdeogramV3(comfy_io.ComfyNode):
if num_images > 1:
gen_request.num_images = num_images
files = {}
if character_img_binary:
files["character_reference_images"] = character_img_binary
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
if files:
gen_request.style_type = "AUTO"
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -809,8 +748,6 @@ class IdeogramV3(comfy_io.ComfyNode):
response_model=IdeogramGenerateResponse,
),
request=gen_request,
files=files if files else None,
content_type="multipart/form-data",
auth_kwargs=auth,
)

View File

@@ -12,7 +12,6 @@ User Guides:
"""
from typing import Union, Optional, Any
from typing_extensions import override
from enum import Enum
import torch
@@ -47,9 +46,9 @@ from comfy_api_nodes.apinode_utils import (
validate_string,
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@@ -86,11 +85,20 @@ class RunwayGen3aAspectRatio(str, Enum):
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
if response.output and len(response.output) > 0:
return response.output[0]
return None
# TODO: replace with updated image validation utils (upstream)
def validate_input_image(image: torch.Tensor) -> bool:
"""
Validate the input image is within the size limits for the Runway API.
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
"""
return image.shape[2] < 8000 and image.shape[1] < 8000
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
@@ -126,438 +134,458 @@ def extract_progress_from_task_status(
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
if response.output and len(response.output) > 0:
return response.output[0]
return None
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=estimated_duration,
node_id=node_id,
)
class RunwayVideoGenNode(ComfyNodeABC):
"""Runway Video Node Base."""
RETURN_TYPES = ("VIDEO",)
FUNCTION = "api_call"
CATEGORY = "api node/video/Runway"
API_NODE = True
async def generate_video(
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
)
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
initial_response = await initial_operation.execute()
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no video data found in response."
)
return True
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no video data found in response.")
video_url = get_video_url_from_task_status(final_response)
return await download_url_to_video_output(video_url)
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
)
async def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
)
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
final_response = await self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response)
return (await download_url_to_video_output(video_url),)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen3a Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
@classmethod
async def execute(
cls,
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_input_image(start_frame)
# Upload image
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
)
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen4 Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
comfy_io.Combo.Input(
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
options=[model.value for model in RunwayGen4TurboAspectRatio],
enum_type=RunwayGen4TurboAspectRatio,
),
comfy_io.Int.Input(
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@classmethod
async def execute(
cls,
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_input_image(start_frame)
# Upload image
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
"""Runway First-Last Frame Node."""
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
)
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video",
category="api node/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
"This would give the generation more time to smoothly transition between the two inputs. "
"Before diving in, review these best practices to ensure that your input selections "
"will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
comfy_io.Image.Input(
"end_frame",
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
"end_frame": (
IO.IMAGE,
{
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
},
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
comfy_io.Combo.Input(
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
enum_type=RunwayGen3aAspectRatio,
),
comfy_io.Int.Input(
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"unique_id": "UNIQUE_ID",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
@classmethod
async def execute(
cls,
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
validate_input_image(start_frame)
validate_input_image(end_frame)
# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
),
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayTextToImageNode(comfy_io.ComfyNode):
class RunwayTextToImageNode(ComfyNodeABC):
"""Runway Text to Image Node."""
RETURN_TYPES = ("IMAGE",)
FUNCTION = "api_call"
CATEGORY = "api node/image/Runway"
API_NODE = True
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayTextToImageNode",
display_name="Runway Text to Image",
category="api node/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
),
comfy_io.Combo.Input(
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayTextToImageRequest,
"ratio",
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
enum_type=RunwayTextToImageAspectRatioEnum,
),
comfy_io.Image.Input(
"reference_image",
tooltip="Optional reference image to guide the generation",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
},
"optional": {
"reference_image": (
IO.IMAGE,
{"tooltip": "Optional reference image to guide the generation"},
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: TaskStatusResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no image data found in response."
)
return True
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
node_id=node_id,
)
@classmethod
async def execute(
cls,
async def api_call(
self,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
) -> comfy_io.NodeOutput:
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# Validate inputs
validate_string(prompt, min_length=1)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_input_image(reference_image)
download_urls = await upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload reference image to comfy api.")
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
# Create request
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
@@ -565,6 +593,7 @@ class RunwayTextToImageNode(comfy_io.ComfyNode):
referenceImages=reference_images,
)
# Execute initial request
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
@@ -573,33 +602,34 @@ class RunwayTextToImageNode(comfy_io.ComfyNode):
response_model=RunwayTextToImageResponse,
),
request=request,
auth_kwargs=auth_kwargs,
auth_kwargs=kwargs,
)
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
# Poll for completion
final_response = await get_response(
initial_response.id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no image data found in response.")
self.validate_response(final_response)
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (await download_url_to_image_tensor(image_url),)
class RunwayExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
RunwayFirstLastFrameNode,
RunwayImageToVideoNodeGen3a,
RunwayImageToVideoNodeGen4,
RunwayTextToImageNode,
]
NODE_CLASS_MAPPINGS = {
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
"RunwayTextToImageNode": RunwayTextToImageNode,
}
async def comfy_entrypoint() -> RunwayExtension:
return RunwayExtension()
NODE_DISPLAY_NAME_MAPPINGS = {
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
"RunwayTextToImageNode": "Runway Text to Image",
}

View File

@@ -2,7 +2,7 @@ from inspect import cleandoc
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.apis.stability_api import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
@@ -15,10 +15,6 @@ from comfy_api_nodes.apis.stability_api import (
Stability_SD3_5_Model,
Stability_SD3_5_GenerationMode,
get_stability_style_presets,
StabilityTextToAudioRequest,
StabilityAudioToAudioRequest,
StabilityAudioInpaintRequest,
StabilityAudioResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@@ -31,10 +27,7 @@ from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor,
tensor_to_bytesio,
validate_string,
audio_bytes_to_audio_input,
audio_input_to_mp3,
)
from comfy_api_nodes.util.validation_utils import validate_audio_duration
import torch
import base64
@@ -656,306 +649,6 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode):
return comfy_io.NodeOutput(returned_image)
class StabilityTextToAudio(comfy_io.ComfyNode):
"""Generates high-quality music and sound effects from text descriptions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
),
],
outputs=[
comfy_io.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput:
validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
method=HttpMethod.POST,
request_model=StabilityTextToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioToAudio(comfy_io.ComfyNode):
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
comfy_io.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
),
comfy_io.Float.Input(
"strength",
default=1,
min=0.01,
max=1.0,
step=0.01,
display_mode=comfy_io.NumberDisplay.slider,
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
optional=True,
),
],
outputs=[
comfy_io.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
) -> comfy_io.NodeOutput:
validate_string(prompt, max_length=10000)
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
method=HttpMethod.POST,
request_model=StabilityAudioToAudioRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityAudioInpaint(comfy_io.ComfyNode):
"""Transforms part of existing audio sample using text instructions."""
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="api node/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Combo.Input(
"model",
options=["stable-audio-2.5"],
),
comfy_io.String.Input("prompt", multiline=True, default=""),
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
comfy_io.Int.Input(
"duration",
default=190,
min=1,
max=190,
step=1,
tooltip="Controls the duration in seconds of the generated audio.",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for generation.",
optional=True,
),
comfy_io.Int.Input(
"steps",
default=8,
min=4,
max=8,
step=1,
tooltip="Controls the number of sampling steps.",
optional=True,
),
comfy_io.Int.Input(
"mask_start",
default=30,
min=0,
max=190,
step=1,
optional=True,
),
comfy_io.Int.Input(
"mask_end",
default=190,
min=0,
max=190,
step=1,
optional=True,
),
],
outputs=[
comfy_io.Audio.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
audio: Input.Audio,
duration: int,
seed: int,
steps: int,
mask_start: int,
mask_end: int,
) -> comfy_io.NodeOutput:
validate_string(prompt, max_length=10000)
if mask_end <= mask_start:
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
validate_audio_duration(audio, 6, 190)
payload = StabilityAudioInpaintRequest(
prompt=prompt,
model=model,
duration=duration,
seed=seed,
steps=steps,
mask_start=mask_start,
mask_end=mask_end,
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
method=HttpMethod.POST,
request_model=StabilityAudioInpaintRequest,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)},
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
response_api = await operation.execute()
if not response_api.audio:
raise ValueError("No audio file was received in response.")
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
@@ -965,9 +658,6 @@ class StabilityExtension(ComfyExtension):
StabilityUpscaleConservativeNode,
StabilityUpscaleCreativeNode,
StabilityUpscaleFastNode,
StabilityTextToAudio,
StabilityAudioToAudio,
StabilityAudioInpaint,
]

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional
import torch
from comfy_api.latest import Input
from comfy_api.input.video_types import VideoInput
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
@@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness(
def validate_video_dimensions(
video: Input.Video,
video: VideoInput,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
@@ -126,7 +126,7 @@ def validate_video_dimensions(
def validate_video_duration(
video: Input.Video,
video: VideoInput,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
):
@@ -151,17 +151,3 @@ def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
return len(images)
def validate_audio_duration(
audio: Input.Audio,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
) -> None:
sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr
eps = 1.0 / sr
if min_duration is not None and dur + eps < min_duration:
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
if max_duration is not None and dur - eps > max_duration:
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")

View File

@@ -1,10 +1,6 @@
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
@@ -23,30 +19,25 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
class AlignYourStepsScheduler(io.ComfyNode):
class AlignYourStepsScheduler:
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
io.Int.Input("steps", default=10, min=1, max=10000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[io.Sigmas.Output()],
)
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model_type, steps, denoise):
# Deprecated: use the V3 schema's `execute` method instead of this.
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
@classmethod
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return io.NodeOutput(torch.FloatTensor([]))
return (torch.FloatTensor([]),)
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
@@ -55,15 +46,8 @@ class AlignYourStepsScheduler(io.ComfyNode):
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return io.NodeOutput(torch.FloatTensor(sigmas))
return (torch.FloatTensor(sigmas), )
class AlignYourStepsExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
AlignYourStepsScheduler,
]
async def comfy_entrypoint() -> AlignYourStepsExtension:
return AlignYourStepsExtension()
NODE_CLASS_MAPPINGS = {
"AlignYourStepsScheduler": AlignYourStepsScheduler,
}

View File

@@ -162,12 +162,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
total_steps = len(args[3])-1
# catch division by zero for log statement; sucks to crash after all sampling is done
try:
speedup = total_steps/(total_steps-easycache.total_steps_skipped)
except ZeroDivisionError:
speedup = 1.0
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
easycache.reset()
guider.model_options = orig_model_options

View File

@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
"reference_latents_method": (("offset", "index", "uso"), ),
}}
RETURN_TYPES = ("CONDITIONING",)
@@ -115,8 +115,6 @@ class FluxKontextMultiReferenceLatentMethod:
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )

View File

@@ -113,20 +113,6 @@ class HunyuanImageToVideo:
out_latent["samples"] = latent
return (positive, out_latent)
class EmptyHunyuanImageLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
@@ -134,5 +120,4 @@ NODE_CLASS_MAPPINGS = {
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
}

View File

@@ -8,16 +8,13 @@ import folder_paths
import comfy.model_management
from comfy.cli_args import args
class EmptyLatentHunyuan3Dv2:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}
}
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@@ -27,6 +24,7 @@ class EmptyLatentHunyuan3Dv2:
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
class Hunyuan3Dv2Conditioning:
@classmethod
def INPUT_TYPES(s):
@@ -83,6 +81,7 @@ class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def INPUT_TYPES(s):
@@ -100,6 +99,7 @@ class VAEDecodeHunyuan3D:
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
@@ -230,9 +230,13 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
z_idx, y_idx, x_idx = pos.unbind(-1)
corner_values = padded[z_idx, y_idx, x_idx]
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)

View File

@@ -625,37 +625,6 @@ class ImageFlip:
return (image,)
class ImageScaleToMaxDimension:
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"upscale_method": (s.upscale_methods,),
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, largest_size):
height = image.shape[1]
width = image.shape[2]
if height > width:
width = round((width / height) * largest_size)
height = largest_size
elif width > height:
height = round((height / width) * largest_size)
width = largest_size
else:
height = largest_size
width = largest_size
samples = image.movedim(-1, 1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1, -1)
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
@@ -670,5 +639,4 @@ NODE_CLASS_MAPPINGS = {
"GetImageSize": GetImageSize,
"ImageRotate": ImageRotate,
"ImageFlip": ImageFlip,
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
}

View File

@@ -1,5 +1,4 @@
import torch
from torch import nn
import folder_paths
import comfy.utils
import comfy.ops
@@ -59,136 +58,6 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
class SigLIPMultiFeatProjModel(torch.nn.Module):
"""
SigLIP Multi-Feature Projection Model for processing style features from different layers
and projecting them into a unified hidden space.
Args:
siglip_token_nums (int): Number of SigLIP tokens, default 257
style_token_nums (int): Number of style tokens, default 256
siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
hidden_size (int): Hidden layer size, default 3072
context_layer_norm (bool): Whether to use context layer normalization, default False
"""
def __init__(
self,
siglip_token_nums: int = 729,
style_token_nums: int = 64,
siglip_token_dims: int = 1152,
hidden_size: int = 3072,
context_layer_norm: bool = True,
device=None, dtype=None, operations=None
):
super().__init__()
# High-level feature processing (layer -2)
self.high_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.high_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Mid-level feature processing (layer -11)
self.mid_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.mid_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Low-level feature processing (layer -20)
self.low_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.low_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
def forward(self, siglip_outputs):
"""
Forward pass function
Args:
siglip_outputs: Output from SigLIP model, containing hidden_states
Returns:
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
"""
dtype = next(self.high_embedding_linear.parameters()).dtype
# Process high-level features (layer -2)
high_embedding = self._process_layer_features(
siglip_outputs[2],
self.high_embedding_linear,
self.high_layer_norm,
self.high_projection,
dtype
)
# Process mid-level features (layer -11)
mid_embedding = self._process_layer_features(
siglip_outputs[1],
self.mid_embedding_linear,
self.mid_layer_norm,
self.mid_projection,
dtype
)
# Process low-level features (layer -20)
low_embedding = self._process_layer_features(
siglip_outputs[0],
self.low_embedding_linear,
self.low_layer_norm,
self.low_projection,
dtype
)
# Concatenate features from all layersmodel_patch
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
def _process_layer_features(
self,
hidden_states: torch.Tensor,
embedding_linear: nn.Module,
layer_norm: nn.Module,
projection: nn.Module,
dtype: torch.dtype
) -> torch.Tensor:
"""
Helper function to process features from a single layer
Args:
hidden_states: Input hidden states [bs, seq_len, dim]
embedding_linear: Embedding linear layer
layer_norm: Layer normalization
projection: Projection layer
dtype: Target data type
Returns:
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
"""
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
embedding = embedding_linear(
hidden_states.to(dtype).transpose(1, 2)
).transpose(1, 2)
# Apply layer normalization
embedding = layer_norm(embedding)
# Project to target hidden space
embedding = projection(embedding)
return embedding
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -204,14 +73,9 @@ class ModelPatchLoader:
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
if 'controlnet_blocks.0.y_rms.weight' in sd:
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
@@ -293,51 +157,7 @@ class QwenImageDiffsynthControlnet:
return (model_patched,)
class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
self.model_patch = model_patch
self.encoded_image = encoded_image
def __call__(self, kwargs):
txt_ids = kwargs.get("txt_ids")
txt = kwargs.get("txt")
siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype)
txt = torch.cat([siglip_embedding, txt], dim=1)
kwargs['txt'] = txt
kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1)
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
class USOStyleReference:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_patch"
EXPERIMENTAL = True
CATEGORY = "advanced/model_patches/flux"
def apply_patch(self, model, model_patch, clip_vision_output):
encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states))
model_patched = model.clone()
model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image))
return (model_patched,)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"USOStyleReference": USOStyleReference,
}

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.59"
__version__ = "0.3.56"

View File

@@ -1 +0,0 @@
"""Server middleware modules"""

View File

@@ -1,52 +0,0 @@
"""Cache control middleware for ComfyUI server"""
from aiohttp import web
from typing import Callable, Awaitable
# Time in seconds
ONE_HOUR: int = 3600
ONE_DAY: int = 86400
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
@web.middleware
async def cache_control(
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
) -> web.Response:
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
response: web.Response = await handler(request)
if (
request.path.endswith(".js")
or request.path.endswith(".css")
or request.path.endswith("index.json")
):
response.headers.setdefault("Cache-Control", "no-cache")
return response
# Early return for non-image files - no cache headers needed
if not request.path.lower().endswith(IMG_EXTENSIONS):
return response
# Handle image files
if response.status == 404:
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
# Success responses and permanent redirects - cache for 1 day
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
elif response.status in (302, 303, 307):
# Temporary redirects - no cache
response.headers.setdefault("Cache-Control", "no-cache")
# Note: 304 Not Modified falls through - no cache headers set
return response

View File

@@ -925,7 +925,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -953,7 +953,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -963,7 +963,7 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@@ -2344,7 +2344,6 @@ async def init_builtin_api_nodes():
"nodes_veo2.py",
"nodes_kling.py",
"nodes_bfl.py",
"nodes_bytedance.py",
"nodes_luma.py",
"nodes_recraft.py",
"nodes_pixverse.py",

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.59"
version = "0.3.56"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.25.11
comfyui-workflow-templates==0.1.81
comfyui-workflow-templates==0.1.70
comfyui-embedded-docs==0.2.6
torch
torchsde

View File

@@ -3,7 +3,11 @@ from urllib import request
#This is the ComfyUI api prompt format.
#If you want it for a specific workflow you can "File -> Export (API)" in the interface.
#If you want it for a specific workflow you can "enable dev mode options"
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
#a button on the UI to save workflows in api format.
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
#this is the one for the default workflow
prompt_text = """

View File

@@ -39,15 +39,20 @@ from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
# Import cache control middleware
from middleware.cache_middleware import cache_control
async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
@web.middleware
async def compress_body(request: web.Request, handler):
accept_encoding = request.headers.get("Accept-Encoding", "")
@@ -724,34 +729,7 @@ class PromptServer():
@routes.post("/interrupt")
async def post_interrupt(request):
try:
json_data = await request.json()
except json.JSONDecodeError:
json_data = {}
# Check if a specific prompt_id was provided for targeted interruption
prompt_id = json_data.get('prompt_id')
if prompt_id:
currently_running, _ = self.prompt_queue.get_current_queue()
# Check if the prompt_id matches any currently running prompt
should_interrupt = False
for item in currently_running:
# item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute)
if item[1] == prompt_id:
logging.info(f"Interrupting prompt {prompt_id}")
should_interrupt = True
break
if should_interrupt:
nodes.interrupt_processing()
else:
logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt")
else:
# No prompt_id provided, do a global interrupt
logging.info("Global interrupt (no prompt_id specified)")
nodes.interrupt_processing()
nodes.interrupt_processing()
return web.Response(status=200)
@routes.post("/free")

View File

@@ -1,255 +0,0 @@
"""Tests for server cache control middleware"""
import pytest
from aiohttp import web
from aiohttp.test_utils import make_mocked_request
from typing import Dict, Any
from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS
pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests
# Test configuration data
CACHE_SCENARIOS = [
# Image file scenarios
{
"name": "image_200_status",
"path": "/test.jpg",
"status": 200,
"expected_cache": f"public, max-age={ONE_DAY}",
"should_have_header": True,
},
{
"name": "image_404_status",
"path": "/missing.jpg",
"status": 404,
"expected_cache": f"public, max-age={ONE_HOUR}",
"should_have_header": True,
},
# JavaScript/CSS scenarios
{
"name": "js_no_cache",
"path": "/script.js",
"status": 200,
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "css_no_cache",
"path": "/styles.css",
"status": 200,
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "index_json_no_cache",
"path": "/api/index.json",
"status": 200,
"expected_cache": "no-cache",
"should_have_header": True,
},
# Non-matching files
{
"name": "html_no_header",
"path": "/index.html",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
{
"name": "txt_no_header",
"path": "/data.txt",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
{
"name": "api_endpoint_no_header",
"path": "/api/endpoint",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
{
"name": "pdf_no_header",
"path": "/file.pdf",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
]
# Status code scenarios for images
IMAGE_STATUS_SCENARIOS = [
# Success statuses get long cache
{"status": 200, "expected": f"public, max-age={ONE_DAY}"},
{"status": 201, "expected": f"public, max-age={ONE_DAY}"},
{"status": 202, "expected": f"public, max-age={ONE_DAY}"},
{"status": 204, "expected": f"public, max-age={ONE_DAY}"},
{"status": 206, "expected": f"public, max-age={ONE_DAY}"},
# Permanent redirects get long cache
{"status": 301, "expected": f"public, max-age={ONE_DAY}"},
{"status": 308, "expected": f"public, max-age={ONE_DAY}"},
# Temporary redirects get no cache
{"status": 302, "expected": "no-cache"},
{"status": 303, "expected": "no-cache"},
{"status": 307, "expected": "no-cache"},
# 404 gets short cache
{"status": 404, "expected": f"public, max-age={ONE_HOUR}"},
]
# Case sensitivity test paths
CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"]
# Edge case test paths
EDGE_CASE_PATHS = [
{
"name": "query_strings_ignored",
"path": "/image.jpg?v=123&size=large",
"expected": f"public, max-age={ONE_DAY}",
},
{
"name": "multiple_dots_in_path",
"path": "/image.min.jpg",
"expected": f"public, max-age={ONE_DAY}",
},
{
"name": "nested_paths_with_images",
"path": "/static/images/photo.jpg",
"expected": f"public, max-age={ONE_DAY}",
},
]
class TestCacheControl:
"""Test cache control middleware functionality"""
@pytest.fixture
def status_handler_factory(self):
"""Create a factory for handlers that return specific status codes"""
def factory(status: int, headers: Dict[str, str] = None):
async def handler(request):
return web.Response(status=status, headers=headers or {})
return handler
return factory
@pytest.fixture
def mock_handler(self, status_handler_factory):
"""Create a mock handler that returns a response with 200 status"""
return status_handler_factory(200)
@pytest.fixture
def handler_with_existing_cache(self, status_handler_factory):
"""Create a handler that returns response with existing Cache-Control header"""
return status_handler_factory(200, {"Cache-Control": "max-age=3600"})
async def assert_cache_header(
self,
response: web.Response,
expected_cache: str = None,
should_have_header: bool = True,
):
"""Helper to assert cache control headers"""
if should_have_header:
assert "Cache-Control" in response.headers
if expected_cache:
assert response.headers["Cache-Control"] == expected_cache
else:
assert "Cache-Control" not in response.headers
# Parameterized tests
@pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"])
async def test_cache_control_scenarios(
self, scenario: Dict[str, Any], status_handler_factory
):
"""Test various cache control scenarios"""
handler = status_handler_factory(scenario["status"])
request = make_mocked_request("GET", scenario["path"])
response = await cache_control(request, handler)
assert response.status == scenario["status"]
await self.assert_cache_header(
response, scenario["expected_cache"], scenario["should_have_header"]
)
@pytest.mark.parametrize("ext", IMG_EXTENSIONS)
async def test_all_image_extensions(self, ext: str, mock_handler):
"""Test all defined image extensions are handled correctly"""
request = make_mocked_request("GET", f"/image{ext}")
response = await cache_control(request, mock_handler)
assert response.status == 200
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
@pytest.mark.parametrize(
"status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}"
)
async def test_image_status_codes(
self, status_scenario: Dict[str, Any], status_handler_factory
):
"""Test different status codes for image requests"""
handler = status_handler_factory(status_scenario["status"])
request = make_mocked_request("GET", "/image.jpg")
response = await cache_control(request, handler)
assert response.status == status_scenario["status"]
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == status_scenario["expected"]
@pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS)
async def test_case_insensitive_image_extension(self, path: str, mock_handler):
"""Test that image extensions are matched case-insensitively"""
request = make_mocked_request("GET", path)
response = await cache_control(request, mock_handler)
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
@pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"])
async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler):
"""Test edge cases like query strings, nested paths, etc."""
request = make_mocked_request("GET", edge_case["path"])
response = await cache_control(request, mock_handler)
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == edge_case["expected"]
# Header preservation tests (special cases not covered by parameterization)
async def test_js_preserves_existing_headers(self, handler_with_existing_cache):
"""Test that .js files preserve existing Cache-Control headers"""
request = make_mocked_request("GET", "/script.js")
response = await cache_control(request, handler_with_existing_cache)
# setdefault should preserve existing header
assert response.headers["Cache-Control"] == "max-age=3600"
async def test_css_preserves_existing_headers(self, handler_with_existing_cache):
"""Test that .css files preserve existing Cache-Control headers"""
request = make_mocked_request("GET", "/styles.css")
response = await cache_control(request, handler_with_existing_cache)
# setdefault should preserve existing header
assert response.headers["Cache-Control"] == "max-age=3600"
async def test_image_preserves_existing_headers(self, status_handler_factory):
"""Test that image cache headers preserve existing Cache-Control"""
handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"})
request = make_mocked_request("GET", "/image.jpg")
response = await cache_control(request, handler)
# setdefault should preserve existing header
assert response.headers["Cache-Control"] == "private, no-cache"
async def test_304_not_modified_inherits_cache(self, status_handler_factory):
"""Test that 304 Not Modified doesn't set cache headers for images"""
handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"})
request = make_mocked_request("GET", "/not-modified.jpg")
response = await cache_control(request, handler)
assert response.status == 304
# Should preserve existing cache header, not override
assert response.headers["Cache-Control"] == "max-age=7200"