Compare commits

..

21 Commits

Author SHA1 Message Date
comfyanonymous
b0338e930b ComfyUI 0.3.57 2025-09-04 02:15:57 -04:00
ComfyUI Wiki
b71f9bcb71 Update template to 0.1.75 (#9711) 2025-09-04 02:14:02 -04:00
comfyanonymous
72855db715 Fix potential rope issue. (#9710) 2025-09-03 22:20:13 -04:00
Alexander Piskun
f48d05a2d1 convert AlignYourStepsScheduler node to V3 schema (#9226) 2025-09-03 21:21:38 -04:00
comfyanonymous
4368d8f87f Update comment in api example. (#9708) 2025-09-03 18:43:29 -04:00
Alexander Piskun
22da0a83e9 [V3] convert Runway API nodes to the V3 schema (#9487)
* convert RunAway API nodes to the V3 schema

* fixed small typo

* fix: add tooltip for "seed" input
2025-09-03 16:18:27 -04:00
Alexander Piskun
50333f1715 api nodes(Ideogram): add Ideogram Character (#9616)
* api nodes(Ideogram): add Ideogram Character

* rename renderingSpeed default value from 'balanced' to 'default'
2025-09-03 16:17:37 -04:00
Alexander Piskun
26d5b86da8 feat(api-nodes): add ByteDance Image nodes (#9477) 2025-09-03 16:17:07 -04:00
ComfyUI Wiki
4f5812b937 Update template to 0.1.73 (#9686) 2025-09-02 20:06:41 -04:00
comfyanonymous
1bcb469089 ImageScaleToMaxDimension node. (#9689) 2025-09-02 20:05:57 -04:00
Deep Roy
464ba1d614 Accept prompt_id in interrupt handler (#9607)
* Accept prompt_id in interrupt handler

* remove a log
2025-09-02 19:41:10 -04:00
comfyanonymous
e3018c2a5a uso -> uxo/uno as requested. (#9688) 2025-09-02 16:12:07 -04:00
comfyanonymous
3412d53b1d USO style reference. (#9677)
Load the projector.safetensors file with the ModelPatchLoader node and use
the siglip_vision_patch14_384.safetensors "clip vision" model and the
USOStyleReferenceNode.
2025-09-02 15:36:22 -04:00
contentis
e2d1e5dad9 Enable Convolution AutoTuning (#9301) 2025-09-01 20:33:50 -04:00
comfyanonymous
27e067ce50 Implement the USO subject identity lora. (#9674)
Use the lora with FluxContextMultiReferenceLatentMethod node set to "uso"
and a ReferenceLatent node with the reference image.
2025-09-01 18:54:02 -04:00
comfyanonymous
9b15155972 Probably not necessary anymore. (#9646) 2025-08-31 01:32:10 -04:00
chaObserv
32a627bf1f SEEDS: update noise decomposition and refactor (#9633)
- Update the decomposition to reflect interval dependency
- Extract phi computations into functions
- Use torch.lerp for interpolation
2025-08-31 00:01:45 -04:00
Alexander Piskun
fe442fac2e convert Primitive nodes to V3 schema (#9372) 2025-08-30 23:21:58 -04:00
Alexander Piskun
d2c502e629 convert nodes_stability.py to V3 schema (#9497) 2025-08-30 23:20:17 -04:00
Alexander Piskun
fea9ea8268 convert Video nodes to V3 schema (#9489) 2025-08-30 23:19:54 -04:00
Alexander Piskun
f949094b3c convert Stable Cascade nodes to V3 schema (#9373) 2025-08-30 23:19:21 -04:00
55 changed files with 2075 additions and 2399 deletions

View File

@@ -143,6 +143,7 @@ class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")

View File

@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
all_intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):

View File

@@ -50,7 +50,13 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
if self.return_all_hidden_states:
all_hs = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = all_hs[:, -2]
outputs["all_hidden_states"] = all_hs
else:
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs

View File

@@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
return sigmas
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
return torch.expm1(h)
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
return (torch.expm1(h) - h) / h
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -1550,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1564,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
fac = 1 / (2 * r)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
continue
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
sigma_s_1 = sigma_fn(lambda_s_1)
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
if inject_noise:
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
continue
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
if inject_noise:
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 2
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
if inject_noise:
segment_factor = (r_1 - r_2) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
# Step 3
b3 = ei_h_phi_2(-h_eta) / r_2
b1 = ei_h_phi_1(-h_eta) - b3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
if inject_noise:
segment_factor = (r_2 - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x

View File

@@ -133,7 +133,6 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
@@ -141,7 +140,6 @@ class Attention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
transformer_options=transformer_options,
**cross_attention_kwargs,
)
@@ -368,7 +366,6 @@ class CustomerAttnProcessor2_0:
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@@ -436,7 +433,7 @@ class CustomerAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
).to(query.dtype)
# linear proj
@@ -700,7 +697,6 @@ class LinearTransformerBlock(nn.Module):
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
transformer_options={},
):
N = hidden_states.shape[0]
@@ -724,7 +720,6 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@@ -734,7 +729,6 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
transformer_options=transformer_options,
)
if self.use_adaln_single:
@@ -749,7 +743,6 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states

View File

@@ -314,7 +314,6 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@@ -340,7 +339,6 @@ class ACEStepTransformer2DModel(nn.Module):
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
transformer_options=transformer_options,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
@@ -395,7 +393,6 @@ class ACEStepTransformer2DModel(nn.Module):
output_length = hidden_states.shape[-1]
transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@@ -405,7 +402,6 @@ class ACEStepTransformer2DModel(nn.Module):
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
transformer_options=transformer_options,
)
return output

View File

@@ -298,8 +298,7 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None,
transformer_options={},
causal = None
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
@@ -364,7 +363,7 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True)
out = self.to_out(out)
if mask is not None:
@@ -489,8 +488,7 @@ class TransformerBlock(nn.Module):
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None,
transformer_options={}
rotary_pos_emb = None
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
@@ -500,12 +498,12 @@ class TransformerBlock(nn.Module):
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -519,10 +517,10 @@ class TransformerBlock(nn.Module):
x = x + residual
else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -608,8 +606,7 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
transformer_options = kwargs.get("transformer_options", {})
patches_replace = transformer_options.get("patches_replace", {})
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
@@ -635,7 +632,7 @@ class ContinuousTransformer(nn.Module):
# Attention layers
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
else:
rotary_pos_emb = None
@@ -648,13 +645,13 @@ class ContinuousTransformer(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:

View File

@@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
)
#@torch.compile()
def forward(self, c, transformer_options={}):
def forward(self, c):
bsz, seqlen1, _ = c.shape
@@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c = self.w1o(output)
return c
@@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):
#@torch.compile()
def forward(self, c, x, transformer_options={}):
def forward(self, c, x):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
torch.cat([cv, xv], dim=1),
)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
self.is_last = is_last
#@torch.compile()
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
def forward(self, c, x, global_cond, **kwargs):
cres, xres = c, x
@@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
c, x = self.attn(c, x, transformer_options=transformer_options)
c, x = self.attn(c, x)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
def forward(self, cx, global_cond, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx, transformer_options=transformer_options)
cx = self.attn(cx)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@@ -473,14 +473,13 @@ class MMDiT(nn.Module):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"],
transformer_options=args["transformer_options"])
args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
@@ -489,13 +488,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
out["img"] = layer(args["img"], args["vec"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:]

View File

@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v, transformer_options={}):
def forward(self, q, k, v):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
out = optimized_attention(q, k, v, self.heads)
return self.out_proj(out)
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False, transformer_options={}):
def forward(self, x, kv, self_attn=False):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = self.attn(x, kv, kv)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv, transformer_options={}):
def forward(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
return x

View File

@@ -173,7 +173,7 @@ class StageB(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, transformer_options={}):
def _down_encode(self, x, r_embed, clip):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip, transformer_options=transformer_options)
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -199,7 +199,7 @@ class StageB(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
def _up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip, transformer_options=transformer_options)
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -228,7 +228,7 @@ class StageB(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
@@ -245,8 +245,8 @@ class StageB(nn.Module):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@@ -182,7 +182,7 @@ class StageC(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
def _down_encode(self, x, r_embed, clip, cnet=None):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -201,7 +201,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip, transformer_options=transformer_options)
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -213,7 +213,7 @@ class StageC(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -235,7 +235,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip, transformer_options=transformer_options)
x = block(x, clip)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -247,7 +247,7 @@ class StageC(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
@@ -262,8 +262,8 @@ class StageC(nn.Module):
# Model Blocks
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
level_outputs = self._down_encode(x, r_embed, clip, cnet)
x = self._up_decode(level_outputs, r_embed, clip, cnet)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)

View File

@@ -193,16 +193,14 @@ class Chroma(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -211,8 +209,7 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -232,19 +229,17 @@ class Chroma(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")

View File

@@ -176,7 +176,6 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
transformer_options={},
**kwargs,
):
"""
@@ -185,7 +184,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@@ -547,7 +546,6 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@@ -573,7 +571,6 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@@ -668,7 +665,6 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@@ -706,7 +702,6 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@@ -714,7 +709,6 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@@ -790,7 +784,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@@ -800,6 +793,5 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
return x

View File

@@ -520,7 +520,6 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@@ -535,7 +534,6 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
class Attention(nn.Module):
@@ -180,8 +180,8 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
@@ -189,7 +189,6 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@@ -197,7 +196,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v, transformer_options=transformer_options)
return self.compute_attention(q, k, v)
class Timesteps(nn.Module):
@@ -460,7 +459,6 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -514,7 +512,6 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -528,7 +525,6 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@@ -538,7 +534,6 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -552,7 +547,6 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -871,7 +865,6 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
for block in self.blocks:
x_B_T_H_W_D = block(

View File

@@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
pe=pe, mask=attn_mask)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
@@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
@@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)

View File

@@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q_shape = q.shape
k_shape = k.shape
@@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x

View File

@@ -106,6 +106,7 @@ class Flux(nn.Module):
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,9 +118,17 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -128,7 +137,6 @@ class Flux(nn.Module):
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
transformer_options["block"] = ("double_block", i, 2)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -136,16 +144,14 @@ class Flux(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -154,8 +160,7 @@ class Flux(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -170,26 +175,23 @@ class Flux(nn.Module):
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
transformer_options["block"] = ("single_block", i, 1)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -240,12 +242,18 @@ class Flux(nn.Module):
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
ref_latents_method = kwargs.get("ref_latents_method", "offset")
for ref in ref_latents:
if index_ref_method:
if ref_latents_method == "index":
index += 1
h_offset = 0
w_offset = 0
elif ref_latents_method == "uxo":
index = 0
h_offset = h_len * patch_size + h
w_offset = w_len * patch_size + w
h += ref.shape[-2]
w += ref.shape[-1]
else:
index = 1
h_offset = 0

View File

@@ -109,7 +109,6 @@ class AsymmetricAttention(nn.Module):
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
transformer_options={},
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
@@ -144,7 +143,7 @@ class AsymmetricAttention(nn.Module):
xy = optimized_attention(q,
k,
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
v, self.num_heads, skip_reshape=True)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
@@ -225,7 +224,6 @@ class AsymmetricJointBlock(nn.Module):
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
transformer_options={},
**attn_kwargs,
):
"""Forward pass of a block.
@@ -258,7 +256,6 @@ class AsymmetricJointBlock(nn.Module):
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
transformer_options=transformer_options,
**attn_kwargs,
)
@@ -527,11 +524,10 @@ class AsymmDiTJoint(nn.Module):
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"],
transformer_options=args["transformer_options"]
crop_y=args["num_tokens"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
@@ -542,7 +538,6 @@ class AsymmDiTJoint(nn.Module):
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
transformer_options=transformer_options,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.

View File

@@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
class HiDreamAttnProcessor_flashattn:
@@ -86,7 +86,6 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
*args,
**kwargs,
) -> torch.FloatTensor:
@@ -134,7 +133,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value, transformer_options=transformer_options)
hidden_states = attention(query, key, value)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -200,7 +199,6 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.Tensor:
return self.processor(
self,
@@ -208,7 +206,6 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
@@ -409,7 +406,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -422,7 +419,6 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens,
image_tokens_masks,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -487,7 +483,6 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -505,7 +500,6 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks,
norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -556,7 +550,6 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
return self.block(
image_tokens,
@@ -564,7 +557,6 @@ class HiDreamImageBlock(nn.Module):
text_tokens,
adaln_input,
rope,
transformer_options=transformer_options,
)
@@ -794,7 +786,6 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
transformer_options=transformer_options,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
@@ -818,7 +809,6 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
transformer_options=transformer_options,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1

View File

@@ -99,16 +99,14 @@ class Hunyuan3Dv2(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -117,8 +115,7 @@ class Hunyuan3Dv2(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options)
attn_mask=attn_mask)
img = torch.cat((txt, img), 1)
@@ -129,19 +126,17 @@ class Hunyuan3Dv2(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask,
"transformer_options": transformer_options},
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)

View File

@@ -78,13 +78,13 @@ class TokenRefinerBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask, transformer_options={}):
def forward(self, x, c, mask):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@@ -115,14 +115,14 @@ class IndividualTokenRefiner(nn.Module):
]
)
def forward(self, x, c, mask, transformer_options={}):
def forward(self, x, c, mask):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m, transformer_options=transformer_options)
x = block(x, c, m)
return x
@@ -150,7 +150,6 @@ class TokenRefiner(nn.Module):
x,
timesteps,
mask,
transformer_options={},
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
@@ -159,7 +158,7 @@ class TokenRefiner(nn.Module):
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
x = self.individual_token_refiner(x, c, mask)
return x
class HunyuanVideo(nn.Module):
@@ -268,7 +267,7 @@ class HunyuanVideo(nn.Module):
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
txt = self.txt_in(txt, timesteps, txt_mask)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -286,14 +285,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -308,13 +307,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
if control is not None: # Controlnet
control_o = control.get("output")

View File

@@ -271,7 +271,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
def forward(self, x, context=None, mask=None, pe=None):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@@ -285,9 +285,9 @@ class CrossAttention(nn.Module):
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
@@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
x += self.attn2(x, context=context, mask=attention_mask)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
@@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -490,8 +490,7 @@ class LTXVModel(torch.nn.Module):
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe,
transformer_options=transformer_options,
pe=pe
)
# 3. Output

View File

@@ -104,7 +104,6 @@ class JointAttention(nn.Module):
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
transformer_options={},
) -> torch.Tensor:
"""
@@ -141,7 +140,7 @@ class JointAttention(nn.Module):
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
return self.out(output)
@@ -269,7 +268,6 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
transformer_options={},
):
"""
Perform a forward pass through the TransformerBlock.
@@ -292,7 +290,6 @@ class JointTransformerBlock(nn.Module):
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -307,7 +304,6 @@ class JointTransformerBlock(nn.Module):
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + self.ffn_norm2(
@@ -498,7 +494,7 @@ class NextDiT(nn.Module):
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
@@ -558,7 +554,7 @@ class NextDiT(nn.Module):
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
# refine image
flat_x = []
@@ -577,7 +573,7 @@ class NextDiT(nn.Module):
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@@ -620,13 +616,12 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
x = layer(x, mask, freqs_cis, adaln_input)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]

View File

@@ -5,9 +5,8 @@ import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any, Callable, Union
from typing import Optional
import logging
import functools
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -18,45 +17,23 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
SAGE_ATTENTION_IS_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError as e:
if model_management.sage_attention_enabled():
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError as e:
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError:
if model_management.flash_attention_enabled():
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
if name not in REGISTERED_ATTENTION_FUNCTIONS:
REGISTERED_ATTENTION_FUNCTIONS[name] = func
else:
logging.warning(f"Attention function {name} already registered, skipping registration.")
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
if name == "optimized":
return optimized_attention
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
if default is ...:
raise KeyError(f"Attention function {name} not found.")
else:
return default
return REGISTERED_ATTENTION_FUNCTIONS[name]
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -114,27 +91,7 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def wrap_attn(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
remove_attn_wrapper_key = False
try:
if "_inside_attn_wrapper" not in kwargs:
transformer_options = kwargs.get("transformer_options", None)
remove_attn_wrapper_key = True
kwargs["_inside_attn_wrapper"] = True
if transformer_options is not None:
if "optimized_attention_override" in transformer_options:
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
return func(*args, **kwargs)
finally:
if remove_attn_wrapper_key:
del kwargs["_inside_attn_wrapper"]
return wrapper
@wrap_attn
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -202,8 +159,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
@wrap_attn
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@@ -273,8 +230,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
@wrap_attn
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -403,8 +359,7 @@ try:
except:
pass
@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@@ -419,7 +374,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
# b h k d -> b k h d
@@ -472,8 +427,8 @@ else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -515,8 +470,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -546,7 +501,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
if tensor_layout == "HND":
if not skip_output_reshape:
@@ -579,8 +534,8 @@ except AttributeError as error:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -642,19 +597,6 @@ else:
optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
@@ -687,7 +629,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
@@ -698,9 +640,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
@@ -804,7 +746,7 @@ class BasicTransformerBlock(nn.Module):
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n)
else:
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
n = self.attn1(n, context=context_attn1, value=value_attn1)
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
@@ -844,7 +786,7 @@ class BasicTransformerBlock(nn.Module):
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
n = self.attn2(n, context=context_attn2, value=value_attn2)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@@ -1075,7 +1017,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)

View File

@@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
return _block_mixing(*args, **kwargs)
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
def _block_mixing(context, x, context_block, x_block, c):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
@@ -622,7 +622,6 @@ def _block_mixing(context, x, context_block, x_block, c, transformer_options={})
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
transformer_options=transformer_options,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@@ -638,7 +637,6 @@ def _block_mixing(context, x, context_block, x_block, c, transformer_options={})
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
transformer_options=transformer_options,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
@@ -960,10 +958,10 @@ class MMDiT(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
@@ -972,7 +970,6 @@ class MMDiT(nn.Module):
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
transformer_options=transformer_options,
)
if control is not None:
control_o = control.get("output")

View File

@@ -120,7 +120,7 @@ class Attention(nn.Module):
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
@@ -146,7 +146,7 @@ class Attention(nn.Module):
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
@@ -453,14 +453,13 @@ class OmniGen2Transformer2DModel(nn.Module):
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
transformer_options=transformer_options,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = self.norm_out(hidden_states, temb)

View File

@@ -132,7 +132,6 @@ class Attention(nn.Module):
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
@@ -160,7 +159,7 @@ class Attention(nn.Module):
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -227,7 +226,6 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
@@ -244,7 +242,6 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
@@ -437,9 +434,9 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
@@ -449,12 +446,11 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]

View File

@@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs, transformer_options={}):
def forward(self, x, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -75,7 +75,6 @@ class WanSelfAttention(nn.Module):
k.view(b, s, n * d),
v,
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
@@ -84,7 +83,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, transformer_options={}, **kwargs):
def forward(self, x, context, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -96,7 +95,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = optimized_attention(q, k, v, heads=self.num_heads)
x = self.o(x)
return x
@@ -117,7 +116,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len, transformer_options={}):
def forward(self, x, context, context_img_len):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -132,9 +131,9 @@ class WanI2VCrossAttention(WanSelfAttention):
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = optimized_attention(q, k, v, heads=self.num_heads)
# output
x = x + img_x
@@ -207,7 +206,6 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_img_len=257,
transformer_options={},
):
r"""
Args:
@@ -226,12 +224,12 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
freqs)
x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -561,12 +559,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)
@@ -744,17 +742,17 @@ class VaceWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength[iii]
del c_skip
# head
@@ -843,12 +841,12 @@ class CameraWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)

View File

@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
for k in sdk:
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
if k.endswith(".weight") and ".linear1." in k:
key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:

View File

@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
def convert_uso_lora(sd):
sd_out = {}
for k in sd:
tensor = sd[k]
k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
.replace(".up.weight", ".lora_up.weight")
.replace(".qkv_lora2.", ".txt_attn.qkv.")
.replace(".qkv_lora1.", ".img_attn.qkv.")
.replace(".proj_lora1.", ".img_attn.proj.")
.replace(".proj_lora2.", ".txt_attn.proj.")
.replace(".qkv_lora.", ".linear1_qkv.")
.replace(".proj_lora.", ".linear2.")
.replace(".processor.", ".")
)
sd_out[k_to] = tensor
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
return convert_uso_lora(sd)
return sd

View File

@@ -433,6 +433,9 @@ class ModelPatcher:
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj

View File

@@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError):
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

View File

@@ -951,7 +951,11 @@ class MagicPrompt2(str, Enum):
class StyleType1(str, Enum):
AUTO = 'AUTO'
GENERAL = 'GENERAL'
REALISTIC = 'REALISTIC'
DESIGN = 'DESIGN'
FICTION = 'FICTION'
class ImagenImageGenerationInstance(BaseModel):
@@ -2676,7 +2680,7 @@ class ReleaseNote(BaseModel):
class RenderingSpeed(str, Enum):
BALANCED = 'BALANCED'
DEFAULT = 'DEFAULT'
TURBO = 'TURBO'
QUALITY = 'QUALITY'
@@ -4918,6 +4922,14 @@ class IdeogramV3EditRequest(BaseModel):
None,
description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class IdeogramV3Request(BaseModel):
@@ -4951,6 +4963,14 @@ class IdeogramV3Request(BaseModel):
style_type: Optional[StyleType1] = Field(
None, description='The type of style to apply'
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class ImagenGenerateImageResponse(BaseModel):

View File

@@ -0,0 +1,336 @@
import logging
from enum import Enum
from typing import Optional
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import (
validate_image_aspect_ratio_range,
get_number_of_images,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import download_url_to_image_tensor, upload_images_to_comfyapi, validate_string
BYTEPLUS_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
class Text2ImageModelName(str, Enum):
seedream3 = "seedream-3-0-t2i-250415"
class Image2ImageModelName(str, Enum):
seededit3 = "seededit-3-0-i2i-250628"
class Text2ImageTaskCreationRequest(BaseModel):
model: Text2ImageModelName = Text2ImageModelName.seedream3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
size: Optional[str] = Field(None)
seed: Optional[int] = Field(0, ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: Image2ImageModelName = Image2ImageModelName.seededit3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: Optional[str] = Field("adaptive")
seed: Optional[int] = Field(..., ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
logging.info(error_msg)
raise RuntimeError(error_msg)
logging.info("ByteDance task succeeded, image URL: %s", response.data[0]["url"])
return response.data[0]["url"]
class ByteDanceImageNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ByteDanceImageNode",
display_name="ByteDance Image",
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Text2ImageModelName],
default=Text2ImageModelName.seedream3.value,
tooltip="Model name",
),
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="The text prompt used to generate the image",
),
comfy_io.Combo.Input(
"size_preset",
options=[label for label, _, _ in RECOMMENDED_PRESETS],
tooltip="Pick a recommended size. Select Custom to use the width and height below",
),
comfy_io.Int.Input(
"width",
default=1024,
min=512,
max=2048,
step=64,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
),
comfy_io.Int.Input(
"height",
default=1024,
min=512,
max=2048,
step=64,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation",
optional=True,
),
comfy_io.Float.Input(
"guidance_scale",
default=2.5,
min=1.0,
max=10.0,
step=0.01,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Higher value makes the image follow the prompt more closely",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the image",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
size_preset: str,
width: int,
height: int,
seed: int,
guidance_scale: float,
watermark: bool,
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
w = h = None
for label, tw, th in RECOMMENDED_PRESETS:
if label == size_preset:
w, h = tw, th
break
if w is None or h is None:
w, h = width, height
if not (512 <= w <= 2048) or not (512 <= h <= 2048):
raise ValueError(
f"Custom size out of range: {w}x{h}. "
"Both width and height must be between 512 and 2048 pixels."
)
payload = Text2ImageTaskCreationRequest(
model=model,
prompt=prompt,
size=f"{w}x{h}",
seed=seed,
guidance_scale=guidance_scale,
watermark=watermark,
)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=BYTEPLUS_ENDPOINT,
method=HttpMethod.POST,
request_model=Text2ImageTaskCreationRequest,
response_model=ImageTaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
class ByteDanceImageEditNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ByteDanceImageEditNode",
display_name="ByteDance Image Edit",
category="api node/video/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in Image2ImageModelName],
default=Image2ImageModelName.seededit3.value,
tooltip="Model name",
),
comfy_io.Image.Input(
"image",
tooltip="The base image to edit",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Instruction to edit image",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation",
optional=True,
),
comfy_io.Float.Input(
"guidance_scale",
default=5.5,
min=1.0,
max=10.0,
step=0.01,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Higher value makes the image follow the prompt more closely",
optional=True,
),
comfy_io.Boolean.Input(
"watermark",
default=True,
tooltip="Whether to add an \"AI generated\" watermark to the image",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
seed: int,
guidance_scale: float,
watermark: bool,
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
source_url = (await upload_images_to_comfyapi(
image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
))[0]
payload = Image2ImageTaskCreationRequest(
model=model,
prompt=prompt,
image=source_url,
seed=seed,
guidance_scale=guidance_scale,
watermark=watermark,
)
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=BYTEPLUS_ENDPOINT,
method=HttpMethod.POST,
request_model=Image2ImageTaskCreationRequest,
response_model=ImageTaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
ByteDanceImageNode,
ByteDanceImageEditNode,
]
async def comfy_entrypoint() -> ByteDanceExtension:
return ByteDanceExtension()

View File

@@ -255,6 +255,7 @@ class IdeogramV1(comfy_io.ComfyNode):
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@@ -383,6 +384,7 @@ class IdeogramV2(comfy_io.ComfyNode):
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@@ -552,6 +554,7 @@ class IdeogramV3(comfy_io.ComfyNode):
category="api node/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@@ -612,11 +615,21 @@ class IdeogramV3(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"rendering_speed",
options=["BALANCED", "TURBO", "QUALITY"],
default="BALANCED",
options=["DEFAULT", "TURBO", "QUALITY"],
default="DEFAULT",
tooltip="Controls the trade-off between generation speed and quality",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Image to use as character reference.",
optional=True,
),
comfy_io.Mask.Input(
"character_mask",
tooltip="Optional mask for character reference image.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
@@ -639,12 +652,46 @@ class IdeogramV3(comfy_io.ComfyNode):
magic_prompt_option="AUTO",
seed=0,
num_images=1,
rendering_speed="BALANCED",
rendering_speed="DEFAULT",
character_image=None,
character_mask=None,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
character_img_binary = None
character_mask_binary = None
if character_image is not None:
input_tensor = character_image.squeeze().cpu()
if character_mask is not None:
character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False)
character_mask = 1.0 - character_mask
if character_mask.shape[1:] != character_image.shape[1:-1]:
raise Exception("Character mask and image must be the same size")
mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_byte_arr = BytesIO()
mask_img.save(mask_byte_arr, format="PNG")
mask_byte_arr.seek(0)
character_mask_binary = mask_byte_arr
character_mask_binary.name = "mask.png"
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_np)
img_byte_arr = BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
character_img_binary = img_byte_arr
character_img_binary.name = "image.png"
elif character_mask is not None:
raise Exception("Character mask requires character image to be present")
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
@@ -693,6 +740,15 @@ class IdeogramV3(comfy_io.ComfyNode):
if num_images > 1:
edit_request.num_images = num_images
files = {
"image": img_binary,
"mask": mask_binary,
}
if character_img_binary:
files["character_reference_images"] = character_img_binary
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -702,10 +758,7 @@ class IdeogramV3(comfy_io.ComfyNode):
response_model=IdeogramGenerateResponse,
),
request=edit_request,
files={
"image": img_binary,
"mask": mask_binary,
},
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
@@ -739,6 +792,14 @@ class IdeogramV3(comfy_io.ComfyNode):
if num_images > 1:
gen_request.num_images = num_images
files = {}
if character_img_binary:
files["character_reference_images"] = character_img_binary
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
if files:
gen_request.style_type = "AUTO"
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -748,6 +809,8 @@ class IdeogramV3(comfy_io.ComfyNode):
response_model=IdeogramGenerateResponse,
),
request=gen_request,
files=files if files else None,
content_type="multipart/form-data",
auth_kwargs=auth,
)

View File

@@ -12,6 +12,7 @@ User Guides:
"""
from typing import Union, Optional, Any
from typing_extensions import override
from enum import Enum
import torch
@@ -46,9 +47,9 @@ from comfy_api_nodes.apinode_utils import (
validate_string,
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@@ -85,20 +86,11 @@ class RunwayGen3aAspectRatio(str, Enum):
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
return None
# TODO: replace with updated image validation utils (upstream)
def validate_input_image(image: torch.Tensor) -> bool:
"""
Validate the input image is within the size limits for the Runway API.
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
"""
return image.shape[2] < 8000 and image.shape[1] < 8000
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
@@ -134,458 +126,438 @@ def extract_progress_from_task_status(
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the image URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
return None
class RunwayVideoGenNode(ComfyNodeABC):
"""Runway Video Node Base."""
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=estimated_duration,
node_id=node_id,
)
RETURN_TYPES = ("VIDEO",)
FUNCTION = "api_call"
CATEGORY = "api node/video/Runway"
API_NODE = True
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
async def generate_video(
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
)
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no video data found in response."
)
return True
initial_response = await initial_operation.execute()
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no video data found in response.")
video_url = get_video_url_from_task_status(final_response)
return await download_url_to_video_output(video_url)
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
@classmethod
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
final_response = await self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response)
return (await download_url_to_video_output(video_url),)
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
)
)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen3a Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
comfy_io.Combo.Input(
"ratio",
enum_type=RunwayGen3aAspectRatio,
options=[model.value for model in RunwayGen4TurboAspectRatio],
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Upload image
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
),
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen4 Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video",
category="api node/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
"This would give the generation more time to smoothly transition between the two inputs. "
"Before diving in, review these best practices to ensure that your input selections "
"will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
comfy_io.Image.Input(
"end_frame",
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
),
comfy_io.Combo.Input(
"ratio",
enum_type=RunwayGen4TurboAspectRatio,
options=[model.value for model in RunwayGen3aAspectRatio],
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
"""Runway First-Last Frame Node."""
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"end_frame": (
IO.IMAGE,
{
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"unique_id": "UNIQUE_ID",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
validate_input_image(end_frame)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
),
),
),
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
class RunwayTextToImageNode(ComfyNodeABC):
"""Runway Text to Image Node."""
RETURN_TYPES = ("IMAGE",)
FUNCTION = "api_call"
CATEGORY = "api node/image/Runway"
API_NODE = True
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
class RunwayTextToImageNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayTextToImageNode",
display_name="Runway Text to Image",
category="api node/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayTextToImageRequest,
comfy_io.Combo.Input(
"ratio",
enum_type=RunwayTextToImageAspectRatioEnum,
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
),
},
"optional": {
"reference_image": (
IO.IMAGE,
{"tooltip": "Optional reference image to guide the generation"},
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: TaskStatusResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no image data found in response."
)
return True
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
node_id=node_id,
comfy_io.Image.Input(
"reference_image",
tooltip="Optional reference image to guide the generation",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# Validate inputs
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_input_image(reference_image)
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
download_urls = await upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload reference image to comfy api.")
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
# Create request
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
@@ -593,7 +565,6 @@ class RunwayTextToImageNode(ComfyNodeABC):
referenceImages=reference_images,
)
# Execute initial request
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
@@ -602,34 +573,33 @@ class RunwayTextToImageNode(ComfyNodeABC):
response_model=RunwayTextToImageResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
# Poll for completion
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
initial_response.id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
)
self.validate_response(final_response)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no image data found in response.")
# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (await download_url_to_image_tensor(image_url),)
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
NODE_CLASS_MAPPINGS = {
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
"RunwayTextToImageNode": RunwayTextToImageNode,
}
class RunwayExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
RunwayFirstLastFrameNode,
RunwayImageToVideoNodeGen3a,
RunwayImageToVideoNodeGen4,
RunwayTextToImageNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
"RunwayTextToImageNode": "Runway Text to Image",
}
async def comfy_entrypoint() -> RunwayExtension:
return RunwayExtension()

View File

@@ -1,5 +1,8 @@
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.apis.stability_api import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
@@ -46,87 +49,94 @@ def get_async_dummy_status(x: StabilityResultsGetResponse):
return StabilityPollStatus.in_progress
class StabilityStableImageUltraNode:
class StabilityStableImageUltraNode(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"elements, colors, and subjects will lead to better results. " +
"To control the weight of a given word use the format `(word:weight)`," +
"where `word` is the word you'd like to control the weight of and `weight`" +
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
"would convey a sky that was blue and green, but more green than blue."
},
"would convey a sky that was blue and green, but more green than blue.",
),
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
{
"default": StabilityAspectRatio.ratio_1_1,
"tooltip": "Aspect ratio of generated image.",
},
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
tooltip="Aspect ratio of generated image.",
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
comfy_io.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
},
"optional": {
"image": (IO.IMAGE,),
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
},
comfy_io.Image.Input(
"image",
optional=True,
),
"image_denoise": (
IO.FLOAT,
{
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
},
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
comfy_io.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
@classmethod
async def execute(
cls,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@@ -144,6 +154,11 @@ class StabilityStableImageUltraNode:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
@@ -161,7 +176,7 @@ class StabilityStableImageUltraNode:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -171,95 +186,106 @@ class StabilityStableImageUltraNode:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityStableImageSD_3_5Node:
class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Combo.Input(
"model",
options=[x.value for x in Stability_SD3_5_Model],
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[x.value for x in StabilityAspectRatio],
default=StabilityAspectRatio.ratio_1_1.value,
tooltip="Aspect ratio of generated image.",
),
comfy_io.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
comfy_io.Float.Input(
"cfg_scale",
default=4.0,
min=1.0,
max=10.0,
step=0.1,
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.Image.Input(
"image",
optional=True,
),
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
comfy_io.Float.Input(
"image_denoise",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"model": ([x.value for x in Stability_SD3_5_Model],),
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
{
"default": StabilityAspectRatio.ratio_1_1,
"tooltip": "Aspect ratio of generated image.",
},
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
),
"cfg_scale": (
IO.FLOAT,
{
"default": 4.0,
"min": 1.0,
"max": 10.0,
"step": 0.1,
"tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"image": (IO.IMAGE,),
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
"image_denoise": (
IO.FLOAT,
{
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
async def execute(
cls,
model: str,
prompt: str,
aspect_ratio: str,
style_preset: str,
seed: int,
cfg_scale: float,
image: Optional[torch.Tensor] = None,
negative_prompt: str = "",
image_denoise: Optional[float] = 0.5,
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
@@ -280,6 +306,11 @@ class StabilityStableImageSD_3_5Node:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
@@ -300,7 +331,7 @@ class StabilityStableImageSD_3_5Node:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -310,72 +341,75 @@ class StabilityStableImageSD_3_5Node:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityUpscaleConservativeNode:
class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Float.Input(
"creativity",
default=0.35,
min=0.2,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"creativity": (
IO.FLOAT,
{
"default": 0.35,
"min": 0.2,
"max": 0.5,
"step": 0.01,
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
**kwargs):
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
seed: int,
negative_prompt: str = "",
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@@ -386,6 +420,11 @@ class StabilityUpscaleConservativeNode:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
@@ -401,7 +440,7 @@ class StabilityUpscaleConservativeNode:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -411,77 +450,81 @@ class StabilityUpscaleConservativeNode:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityUpscaleCreativeNode:
class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
"""
Upscale image with minimal alterations to 4K resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
),
comfy_io.Float.Input(
"creativity",
default=0.3,
min=0.1,
max=0.5,
step=0.01,
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
),
comfy_io.Combo.Input(
"style_preset",
options=get_stability_style_presets(),
tooltip="Optional desired style of generated image.",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967294,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
comfy_io.String.Input(
"negative_prompt",
default="",
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
force_input=True,
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"creativity": (
IO.FLOAT,
{
"default": 0.3,
"min": 0.1,
"max": 0.5,
"step": 0.01,
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
},
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
**kwargs):
async def execute(
cls,
image: torch.Tensor,
prompt: str,
creativity: float,
style_preset: str,
seed: int,
negative_prompt: str = "",
) -> comfy_io.NodeOutput:
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
@@ -494,6 +537,11 @@ class StabilityUpscaleCreativeNode:
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
@@ -510,7 +558,7 @@ class StabilityUpscaleCreativeNode:
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -525,7 +573,8 @@ class StabilityUpscaleCreativeNode:
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=kwargs,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
)
response_poll: StabilityResultsGetResponse = await operation.execute()
@@ -535,41 +584,48 @@ class StabilityUpscaleCreativeNode:
image_data = base64.b64decode(response_poll.result)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
class StabilityUpscaleFastNode:
class StabilityUpscaleFastNode(comfy_io.ComfyNode):
"""
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="api node/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input("image"),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
},
"optional": {
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(self, image: torch.Tensor, **kwargs):
async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput:
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
"image": image_binary
}
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
@@ -580,7 +636,7 @@ class StabilityUpscaleFastNode:
request=EmptyRequest(),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response_api = await operation.execute()
@@ -590,24 +646,20 @@ class StabilityUpscaleFastNode:
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
return comfy_io.NodeOutput(returned_image)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"StabilityStableImageUltraNode": StabilityStableImageUltraNode,
"StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
"StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
"StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
"StabilityUpscaleFastNode": StabilityUpscaleFastNode,
}
class StabilityExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
StabilityStableImageUltraNode,
StabilityStableImageSD_3_5Node,
StabilityUpscaleConservativeNode,
StabilityUpscaleCreativeNode,
StabilityUpscaleFastNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
"StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
"StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
"StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
"StabilityUpscaleFastNode": "Stability AI Upscale Fast",
}
async def comfy_entrypoint() -> StabilityExtension:
return StabilityExtension()

View File

@@ -1,6 +1,10 @@
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
@@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
class AlignYourStepsScheduler:
class AlignYourStepsScheduler(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
io.Int.Input("steps", default=10, min=1, max=10000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[io.Sigmas.Output()],
)
def get_sigmas(self, model_type, steps, denoise):
# Deprecated: use the V3 schema's `execute` method instead of this.
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
@classmethod
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
@@ -46,8 +55,15 @@ class AlignYourStepsScheduler:
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"AlignYourStepsScheduler": AlignYourStepsScheduler,
}
class AlignYourStepsExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
AlignYourStepsScheduler,
]
async def comfy_entrypoint() -> AlignYourStepsExtension:
return AlignYourStepsExtension()

View File

@@ -1,503 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import math
import comfy.model_patcher
if TYPE_CHECKING:
from uuid import UUID
def easysortblock_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
timestep: float = args[1]
model_options: dict[str] = args[2]
easycache: EasySortblockHolder = model_options["transformer_options"]["easycache"]
# initialize predict_ratios
if easycache.initial_step:
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
relevant_sigmas = []
for i,sigma in enumerate(sample_sigmas):
if easycache.check_if_within_timesteps(sigma):
relevant_sigmas.append((i, sigma))
start_index = relevant_sigmas[0][0]
end_index = relevant_sigmas[-1][0]
easycache.predict_ratios = torch.linspace(easycache.start_predict_ratio, easycache.end_predict_ratio, end_index - start_index + 1)
easycache.predict_start_index = start_index
easycache.skip_current_step = False
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
# prepare next x_prev
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
if do_easycache:
easycache.check_metadata(x)
if easycache.has_x_prev_subsampled():
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step
easycache.skip_current_step = True
easycache.steps_skipped.append(easycache.step_count)
else:
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
if easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
output_change_rate = output_change / easycache.output_prev_norm
easycache.output_change_rates.append(output_change_rate.item())
if easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - approx_output_change_rate: {approx_output_change_rate}")
if input_change is not None:
easycache.relative_transformation_rate = output_change / input_change
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - output_change_rate: {output_change_rate}")
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
easycache.output_prev_subsampled = easycache.subsample(output)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
# increment step count
easycache.step_count += 1
easycache.initial_step = False
return output
def easysortblock_outer_sample_wrapper(executor, *args, **kwargs):
"""
This OUTER_SAMPLE wrapper makes sure EasySortblock is prepped for current run, and all memory usage is cleared at the end.
"""
try:
guider = executor.class_obj
orig_model_options = guider.model_options
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
# clone and prepare timesteps
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
easycache: EasySortblockHolder = guider.model_options['transformer_options']['easycache']
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
return executor(*args, **kwargs)
finally:
easycache = guider.model_options['transformer_options']['easycache']
output_change_rates = easycache.output_change_rates
approx_output_change_rates = easycache.approx_output_change_rates
if easycache.verbose:
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
total_steps = len(args[3])-1
logging.info(f"{easycache.name} - skipped {len(easycache.steps_skipped)}/{total_steps} steps")# ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
logging.info(f"{easycache.name} - skipped steps: {easycache.steps_skipped}")
easycache.reset()
guider.model_options = orig_model_options
def model_forward_wrapper(executor, *args, **kwargs):
# TODO: make work with batches of conds
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
sigmas = transformer_options["sigmas"]
sb_holder: EasySortblockHolder = transformer_options["easycache"]
# if initial step, prepare everything for Sortblock
if sb_holder.initial_step:
logging.info(f"EasySortblock: inside model {executor.class_obj.__class__.__name__}")
# TODO: generalize for other models
# these won't stick around past this step; should store on sb_holder instead
logging.info(f"EasySortblock: preparing {len(executor.class_obj.double_blocks)} double blocks and {len(executor.class_obj.single_blocks)} single blocks")
if hasattr(executor.class_obj, "double_blocks"):
for block in executor.class_obj.double_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "single_blocks"):
for block in executor.class_obj.single_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "blocks"):
for block in executor.class_obj.block:
prepare_block(block, sb_holder)
if sb_holder.skip_current_step:
predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
predict_ratio = sb_holder.predict_ratios[predict_index]
logging.info(f"EasySortblock: skipping step {sb_holder.step_count}, predict_ratio: {predict_ratio}")
# reuse_ratio = 1.0 - predict_ratio
for block_type, blocks in sb_holder.blocks_per_type.items():
for block in blocks:
cache: BlockCache = block.__block_cache
cache.allowed_to_skip = False
sorted_blocks = sorted(blocks, key=lambda x: (x.__block_cache.consecutive_skipped_steps, x.__block_cache.prev_change_rate))
# for block in sorted_blocks:
# pass
threshold_index = int(len(sorted_blocks) * predict_ratio)
# blocks with lower similarity are marked for recomputation
for block in sorted_blocks[:threshold_index]:
cache: BlockCache = block.__block_cache
cache.allowed_to_skip = True
logging.info(f"EasySortblock: skip block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
not_skipped = [block for block in blocks if not block.__block_cache.allowed_to_skip]
for block in not_skipped:
logging.info(f"EasySortblock: reco block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
logging.info(f"EasySortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for prediction and {len(sorted_blocks[threshold_index:])} blocks for recomputation")
# return executor(*args, **kwargs)
to_return = executor(*args, **kwargs)
return to_return
def block_forward_factory(func, block):
def block_forward_wrapper(*args, **kwargs):
transformer_options: dict[str] = kwargs.get("transformer_options")
sigmas = transformer_options["sigmas"]
sb_holder: EasySortblockHolder = transformer_options["easycache"]
cache: BlockCache = block.__block_cache
# make sure stream count is properly set for this block
if sb_holder.initial_step:
sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
cache.block_index = transformer_options['block'][1]
cache.stream_count = transformer_options['block'][2]
if sb_holder.is_past_end_timestep(sigmas):
return func(*args, **kwargs)
# do sortblock stuff
x = cache.get_next_x_prev(args, kwargs)
# prepare next_x_prev
next_x_prev = cache.get_next_x_prev(args, kwargs, clone=True)
input_change = None
do_sortblock = sb_holder.should_do_easycache(sigmas)
if do_sortblock:
# TODO: checkmetadata
if cache.has_x_prev_subsampled():
input_change = (cache.subsample(x, clone=False) - cache.x_prev_subsampled).flatten().abs().mean()
if cache.has_output_prev_norm() and cache.has_relative_transformation_rate():
approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
cache.cumulative_change_rate += approx_output_change_rate
if cache.allowed_to_skip:
# if cache.cumulative_change_rate < sb_holder.reuse_threshold:
# accumulate error + skip block
# cache.want_to_skip = True
# if cache.allowed_to_skip:
cache.consecutive_skipped_steps += 1
cache.prev_change_rate = approx_output_change_rate
return cache.apply_cache_diff(x, sb_holder)
else:
# reset error; NOT skipping block and recalculating
cache.cumulative_change_rate = 0.0
cache.prev_change_rate = approx_output_change_rate
cache.want_to_skip = False
cache.consecutive_skipped_steps = 0
# output_raw is expected to have cache.stream_count elements if count is greaater than 1 (double block, etc.)
output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
# if more than one stream from block, only use first one
if isinstance(output_raw, tuple):
output = output_raw[0]
else:
output = output_raw
if cache.has_output_prev_norm():
output_change = (cache.subsample(output, clone=False) - cache.output_prev_subsampled).flatten().abs().mean()
# if verbose in future
output_change_rate = output_change / cache.output_prev_norm
cache.output_change_rates.append(output_change_rate.item())
if cache.has_relative_transformation_rate():
approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
cache.approx_output_change_rates.append(approx_output_change_rate.item())
if input_change is not None:
cache.relative_transformation_rate = output_change / input_change
# TODO: allow cache_diff to be offloaded
cache.update_cache_diff(output_raw, next_x_prev)
cache.x_prev_subsampled = cache.subsample(next_x_prev)
cache.output_prev_subsampled = cache.subsample(output)
cache.output_prev_norm = output.flatten().abs().mean()
return output_raw
return block_forward_wrapper
def prepare_block(block, sb_holder: EasySortblockHolder, stream_count: int=1):
sb_holder.add_to_all_blocks(block)
block.__original_forward = block.forward
block.forward = block_forward_factory(block.__original_forward, block)
block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)
def clean_block(block):
block.forward = block.__original_forward
del block.__original_forward
del block.__block_cache
class BlockCache:
def __init__(self, subsample_factor: int=8, verbose: bool=False):
self.subsample_factor = subsample_factor
self.verbose = verbose
self.stream_count = 1
self.block_index = 0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.prev_change_rate = 0.0
# cached values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.cache_diff: list[torch.Tensor] = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.steps_skipped: list[int] = []
self.consecutive_skipped_steps = 0
# self.state_metadata = None
self.want_to_skip = False
self.allowed_to_skip = False
def has_cache_diff(self) -> bool:
return self.cache_diff[0] is not None
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def get_next_x_prev(self, d_args: tuple[torch.Tensor, ...], d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> tuple[torch.Tensor, ...]:
if self.stream_count == 1:
if clone:
return d_args[0].clone()
return d_args[0]
keys = list(d_kwargs.keys())[:self.stream_count]
orig_inputs = []
for key in keys:
if clone:
orig_inputs.append(d_kwargs[key].clone())
else:
orig_inputs.append(d_kwargs[key])
return tuple(orig_inputs)
def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
# subsample only the first compoenent
if isinstance(x, tuple):
return self.subsample(x[0], clone)
if self.subsample_factor > 1:
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def apply_cache_diff(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], sb_holder: EasySortblockHolder):
self.steps_skipped.append(sb_holder.step_count)
if not isinstance(x, tuple):
x = (x, )
to_return = tuple([x[i] + self.cache_diff[i] for i in range(self.stream_count)])
if len(to_return) == 1:
return to_return[0]
return to_return
def update_cache_diff(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
if not isinstance(output_raw, tuple):
output_raw = (output_raw, )
if not isinstance(x, tuple):
x = (x, )
self.cache_diff = tuple([output_raw[i] - x[i] for i in range(self.stream_count)])
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.prev_change_rate = 0.0
self.x_prev_subsampled = None
self.output_prev_subsampled = None
self.output_prev_norm = None
self.cache_diff = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.steps_skipped = []
self.consecutive_skipped_steps = 0
self.want_to_skip = False
self.allowed_to_skip = False
return self
class EasySortblockHolder:
def __init__(self, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, max_skipped_steps: int,
start_percent: float, end_percent: float, subsample_factor: int, verbose: bool=False):
self.name = "EasySortblock"
self.reuse_threshold = reuse_threshold
self.start_predict_ratio = start_predict_ratio
self.end_predict_ratio = end_predict_ratio
self.max_skipped_steps = max_skipped_steps
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.verbose = verbose
# timestep values
self.start_t = 0.0
self.end_t = 0.0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.initial_step = True
self.step_count = 0
self.predict_ratios = []
self.skip_current_step = False
self.predict_start_index = 0
# cache values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.steps_skipped: list[int] = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.state_metadata = None
self.all_blocks = []
self.blocks_per_type = {}
def add_to_all_blocks(self, block):
self.all_blocks.append(block)
def add_to_blocks_per_type(self, block, block_type: str):
self.blocks_per_type.setdefault(block_type, []).append(block)
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
def should_do_easycache(self, timestep: float) -> bool:
return (timestep[0] <= self.start_t).item()
def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
return (timestep <= self.start_t).item() and (timestep > self.end_t).item()
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
if self.subsample_factor > 1:
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape)
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
logging.info(f"EasySortblock: resetting {len(self.all_blocks)} blocks")
for block in self.all_blocks:
clean_block(block)
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.initial_step = True
self.step_count = 0
self.predict_ratios = []
self.skip_current_step = False
self.predict_start_index = 0
self.x_prev_subsampled = None
self.output_prev_subsampled = None
self.output_prev_norm = None
self.steps_skipped = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.state_metadata = None
self.all_blocks = []
self.blocks_per_type = {}
return self
def clone(self):
return EasySortblockHolder(self.reuse_threshold, self.start_predict_ratio, self.end_predict_ratio, self.max_skipped_steps,
self.start_percent, self.end_percent, self.subsample_factor, self.verbose)
class EasySortblockScaledNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EasySortblockScaled",
display_name="EasySortblockScaled",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add Sortblock to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with Sortblock."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
# TODO: check for specific flavors of supported models
model = model.clone()
model.model_options["transformer_options"]["easycache"] = EasySortblockHolder(reuse_threshold, start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", easysortblock_predict_noise_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", easysortblock_outer_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
return io.NodeOutput(model)
class EasySortblockExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
# EasySortblockNode,
EasySortblockScaledNode,
]
def comfy_entrypoint():
return EasySortblockExtension()

View File

@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index"), ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
}}
RETURN_TYPES = ("CONDITIONING",)
@@ -115,6 +115,8 @@ class FluxKontextMultiReferenceLatentMethod:
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )

View File

@@ -625,6 +625,37 @@ class ImageFlip:
return (image,)
class ImageScaleToMaxDimension:
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"upscale_method": (s.upscale_methods,),
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, largest_size):
height = image.shape[1]
width = image.shape[2]
if height > width:
width = round((width / height) * largest_size)
height = largest_size
elif width > height:
height = round((height / width) * largest_size)
width = largest_size
else:
height = largest_size
width = largest_size
samples = image.movedim(-1, 1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1, -1)
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
@@ -639,4 +670,5 @@ NODE_CLASS_MAPPINGS = {
"GetImageSize": GetImageSize,
"ImageRotate": ImageRotate,
"ImageFlip": ImageFlip,
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
}

View File

@@ -1,4 +1,5 @@
import torch
from torch import nn
import folder_paths
import comfy.utils
import comfy.ops
@@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
class SigLIPMultiFeatProjModel(torch.nn.Module):
"""
SigLIP Multi-Feature Projection Model for processing style features from different layers
and projecting them into a unified hidden space.
Args:
siglip_token_nums (int): Number of SigLIP tokens, default 257
style_token_nums (int): Number of style tokens, default 256
siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
hidden_size (int): Hidden layer size, default 3072
context_layer_norm (bool): Whether to use context layer normalization, default False
"""
def __init__(
self,
siglip_token_nums: int = 729,
style_token_nums: int = 64,
siglip_token_dims: int = 1152,
hidden_size: int = 3072,
context_layer_norm: bool = True,
device=None, dtype=None, operations=None
):
super().__init__()
# High-level feature processing (layer -2)
self.high_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.high_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Mid-level feature processing (layer -11)
self.mid_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.mid_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Low-level feature processing (layer -20)
self.low_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.low_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
def forward(self, siglip_outputs):
"""
Forward pass function
Args:
siglip_outputs: Output from SigLIP model, containing hidden_states
Returns:
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
"""
dtype = next(self.high_embedding_linear.parameters()).dtype
# Process high-level features (layer -2)
high_embedding = self._process_layer_features(
siglip_outputs[2],
self.high_embedding_linear,
self.high_layer_norm,
self.high_projection,
dtype
)
# Process mid-level features (layer -11)
mid_embedding = self._process_layer_features(
siglip_outputs[1],
self.mid_embedding_linear,
self.mid_layer_norm,
self.mid_projection,
dtype
)
# Process low-level features (layer -20)
low_embedding = self._process_layer_features(
siglip_outputs[0],
self.low_embedding_linear,
self.low_layer_norm,
self.low_projection,
dtype
)
# Concatenate features from all layersmodel_patch
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
def _process_layer_features(
self,
hidden_states: torch.Tensor,
embedding_linear: nn.Module,
layer_norm: nn.Module,
projection: nn.Module,
dtype: torch.dtype
) -> torch.Tensor:
"""
Helper function to process features from a single layer
Args:
hidden_states: Input hidden states [bs, seq_len, dim]
embedding_linear: Embedding linear layer
layer_norm: Layer normalization
projection: Projection layer
dtype: Target data type
Returns:
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
"""
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
embedding = embedding_linear(
hidden_states.to(dtype).transpose(1, 2)
).transpose(1, 2)
# Apply layer normalization
embedding = layer_norm(embedding)
# Project to target hidden space
embedding = projection(embedding)
return embedding
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -73,9 +204,14 @@ class ModelPatchLoader:
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
if 'controlnet_blocks.0.y_rms.weight' in sd:
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
@@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet:
return (model_patched,)
class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
self.model_patch = model_patch
self.encoded_image = encoded_image
def __call__(self, kwargs):
txt_ids = kwargs.get("txt_ids")
txt = kwargs.get("txt")
siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype)
txt = torch.cat([siglip_embedding, txt], dim=1)
kwargs['txt'] = txt
kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1)
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
class USOStyleReference:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_patch"
EXPERIMENTAL = True
CATEGORY = "advanced/model_patches/flux"
def apply_patch(self, model, model_patch, clip_vision_output):
encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states))
model_patched = model.clone()
model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image))
return (model_patched,)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"USOStyleReference": USOStyleReference,
}

View File

@@ -1,98 +1,109 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
import sys
from typing_extensions import override
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
from comfy_api.latest import ComfyExtension, io
class String(ComfyNodeABC):
class String(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveString",
display_name="String",
category="utils/primitive",
inputs=[
io.String.Input("value"),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
class StringMultiline(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {"multiline": True,},)},
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Int(ComfyNodeABC):
class StringMultiline(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveStringMultiline",
display_name="String (Multiline)",
category="utils/primitive",
inputs=[
io.String.Input("value", multiline=True),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: int) -> tuple[int]:
return (value,)
class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
}
RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: float) -> tuple[float]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(ComfyNodeABC):
class Int(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveInt",
display_name="Int",
category="utils/primitive",
inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
],
outputs=[io.Int.Output()],
)
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: bool) -> tuple[bool]:
return (value,)
@classmethod
def execute(cls, value: int) -> io.NodeOutput:
return io.NodeOutput(value)
NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveStringMultiline": StringMultiline,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}
class Float(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveFloat",
display_name="Float",
category="utils/primitive",
inputs=[
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
],
outputs=[io.Float.Output()],
)
NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveStringMultiline": "String (Multiline)",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}
@classmethod
def execute(cls, value: float) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveBoolean",
display_name="Boolean",
category="utils/primitive",
inputs=[
io.Boolean.Input("value"),
],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, value: bool) -> io.NodeOutput:
return io.NodeOutput(value)
class PrimitivesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
String,
StringMultiline,
Int,
Float,
Boolean,
]
async def comfy_entrypoint() -> PrimitivesExtension:
return PrimitivesExtension()

View File

@@ -1,462 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import math
import comfy.model_patcher
if TYPE_CHECKING:
from uuid import UUID
def prepare_noise_wrapper(executor, *args, **kwargs):
try:
transformer_options: dict[str] = args[2]["transformer_options"]
sb_holder: SortblockHolder = transformer_options["sortblock"]
if sb_holder.initial_step:
sample_sigmas = transformer_options["sample_sigmas"]
relevant_sigmas = []
# find start and end steps, then use to interpolate between start and end predict ratios
for i,sigma in enumerate(sample_sigmas):
if sb_holder.check_if_within_timesteps(sigma):
relevant_sigmas.append((i, sigma))
start_index = relevant_sigmas[0][0]
end_index = relevant_sigmas[-1][0]
sb_holder.predict_ratios = torch.linspace(sb_holder.start_predict_ratio, sb_holder.end_predict_ratio, end_index - start_index + 1)
sb_holder.predict_start_index = start_index
return executor(*args, **kwargs)
finally:
transformer_options: dict[str] = args[2]["transformer_options"]
sb_holder: SortblockHolder = transformer_options["sortblock"]
sb_holder.step_count += 1
if sb_holder.should_do_sortblock():
sb_holder.active_steps += 1
def outer_sample_wrapper(executor, *args, **kwargs):
try:
logging.info("Sortblock: inside outer_sample!")
guider = executor.class_obj
orig_model_options = guider.model_options
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
# clone and prepare timesteps
sb_holder = guider.model_options["transformer_options"]["sortblock"]
guider.model_options["transformer_options"]["sortblock"] = sb_holder.clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
sb_holder: SortblockHolder = guider.model_options["transformer_options"]["sortblock"]
logging.info(f"Sortblock: enabled - threshold: {sb_holder.start_predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}")
return executor(*args, **kwargs)
finally:
sb_holder = guider.model_options["transformer_options"]["sortblock"]
logging.info(f"Sortblock: final step count: {sb_holder.step_count}")
sb_holder.reset()
guider.model_options = orig_model_options
def model_forward_wrapper(executor, *args, **kwargs):
# TODO: make work with batches of conds
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
sigmas = transformer_options["sigmas"]
sb_holder: SortblockHolder = transformer_options["sortblock"]
sb_holder.update_should_do_sortblock(sigmas)
# if initial step, prepare everything for Sortblock
if sb_holder.initial_step:
logging.info(f"Sortblock: inside model {executor.class_obj.__class__.__name__}")
# TODO: generalize for other models
# these won't stick around past this step; should store on sb_holder instead
logging.info(f"Sortblock: preparing {len(executor.class_obj.double_blocks)} double blocks and {len(executor.class_obj.single_blocks)} single blocks")
if hasattr(executor.class_obj, "double_blocks"):
for block in executor.class_obj.double_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "single_blocks"):
for block in executor.class_obj.single_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "blocks"):
for block in executor.class_obj.block:
prepare_block(block, sb_holder)
# when 0: Initialization(1)
if sb_holder.step_modulus == 0:
logging.info(f"Sortblock: for step {sb_holder.step_count}, all blocks are marked for recomputation")
# all features are computed, input-outputs changes for all DiT blocks are stored for relative step 'k'
sb_holder.activated_steps.append(sb_holder.step_count)
for block in sb_holder.all_blocks:
cache: BlockCache = block.__block_cache
cache.mark_recompute()
# all block operations are performed in forward pass of model
to_return = executor(*args, **kwargs)
# when 1: Select DiT blocks(4)
if sb_holder.step_modulus == 1:
predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
predict_ratio = sb_holder.predict_ratios[predict_index]
logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}")
reuse_ratio = 1.0 - predict_ratio
for block_type, blocks in sb_holder.blocks_per_type.items():
sorted_blocks = sorted(blocks, key=lambda x: x.__block_cache.cosine_similarity)
threshold_index = int(len(sorted_blocks) * reuse_ratio)
# blocks with lower similarity are marked for recomputation
for block in sorted_blocks[:threshold_index]:
cache: BlockCache = block.__block_cache
cache.mark_recompute()
# blocks with higher similarity are marked for prediction
for block in sorted_blocks[threshold_index:]:
cache: BlockCache = block.__block_cache
cache.mark_predict()
logging.info(f"Sortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for recomputation and {len(sorted_blocks[threshold_index:])} blocks for prediction")
if sb_holder.initial_step:
sb_holder.initial_step = False
return to_return
def block_forward_factory(func, block):
def block_forward_wrapper(*args, **kwargs):
transformer_options: dict[str] = kwargs.get("transformer_options")
sb_holder: SortblockHolder = transformer_options["sortblock"]
cache: BlockCache = block.__block_cache
# make sure stream count is properly set for this block
if sb_holder.initial_step:
sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
cache.block_index = transformer_options['block'][1]
cache.stream_count = transformer_options['block'][2]
# do sortblock stuff
if cache.recompute and sb_holder.step_modulus != 1:
# clone relevant inputs
orig_inputs = cache.get_orig_inputs(args, kwargs, clone=True)
# get block outputs
# NOTE: output_raw is expected to have cache.stream_count elements if count is greaater than 1 (double block, etc.)
if cache.stream_count == 1:
zzz = 10
output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
# perform derivative approximation;
cache.derivative_approximation(sb_holder, output_raw, orig_inputs)
# if step_modulus is 0, input-output changes for DiT block are stored
if sb_holder.step_modulus == 0:
cache.cache_previous_residual(output_raw, orig_inputs)
else:
# if not to recompute, predict features for current timestep
orig_inputs = cache.get_orig_inputs(args, kwargs, clone=False)
# when 1: Linear Prediction(2)
# if step_modulus is 1, store block residuals as 'current' after applying taylor_formula
if sb_holder.step_modulus == 1:
cache.cache_current_residual(sb_holder)
# based on features computed in last timestep, all features for current timestep are predicted using Eq. 4,
# input-output changes for all DiT blocks are stored for relative step 'k+1'
output_raw = cache.apply_linear_prediction(sb_holder, orig_inputs)
# when 1: Identify Changes(3)
if sb_holder.step_modulus == 1:
# based on features computed in last timestep, all features for current timestep are predicted using Eq. 4,
# input-output changes for all DiT blocks are stored for relative step 'k+1'
cache.calculate_cosine_similarity()
# return output_raw
return output_raw
return block_forward_wrapper
def perform_sortblock(blocks: list):
...
def prepare_block(block, sb_holder: SortblockHolder, stream_count: int=1):
sb_holder.add_to_all_blocks(block)
block.__original_forward = block.forward
block.forward = block_forward_factory(block.__original_forward, block)
block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)
def clean_block(block):
block.forward = block.__original_forward
del block.__original_forward
del block.__block_cache
def subsample(x: torch.Tensor, factor: int, clone: bool=True) -> torch.Tensor:
if factor > 1:
to_return = x[..., ::factor, ::factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
class BlockCache:
def __init__(self, subsample_factor: int=8, verbose: bool=False):
self.subsample_factor = subsample_factor
self.verbose = verbose
self.stream_count = 1
self.recompute = False
self.block_index = 0
# cached values
self.previous_residual_subsampled: torch.Tensor = None
self.current_residual_subsampled: torch.Tensor = None
self.cosine_similarity: float = None
self.previous_taylor_factors: dict[int, torch.Tensor] = {}
self.current_taylor_factors: dict[int, torch.Tensor] = {}
def mark_recompute(self):
self.recompute = True
def mark_predict(self):
self.recompute = False
def cache_previous_residual(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
if isinstance(output_raw, tuple):
output_raw = output_raw[0]
if isinstance(orig_inputs, tuple):
orig_inputs = orig_inputs[0]
del self.previous_residual_subsampled
self.previous_residual_subsampled = subsample(output_raw - orig_inputs, self.subsample_factor, clone=True)
def cache_current_residual(self, sb_holder: SortblockHolder):
del self.current_residual_subsampled
self.current_residual_subsampled = subsample(self.use_taylor_formula(sb_holder)[0], self.subsample_factor, clone=True)
def get_orig_inputs(self, d_args: tuple, d_kwargs: dict, clone: bool=True) -> tuple[torch.Tensor, ...]:
if self.stream_count == 1:
if clone:
return d_args[0].clone()
return d_args[0]
keys = list(d_kwargs.keys())[:self.stream_count]
orig_inputs = []
for key in keys:
if clone:
orig_inputs.append(d_kwargs[key].clone())
else:
orig_inputs.append(d_kwargs[key])
return tuple(orig_inputs)
def apply_linear_prediction(self, sb_holder: SortblockHolder, orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
drop_tuple = False
if not isinstance(orig_inputs, tuple):
orig_inputs = (orig_inputs,)
drop_tuple = True
taylor_results = self.use_taylor_formula(sb_holder)
for output, taylor_result in zip(orig_inputs, taylor_results):
if output.shape != taylor_result.shape:
zzz = 10
output += taylor_result
if drop_tuple:
orig_inputs = orig_inputs[0]
return orig_inputs
def calculate_cosine_similarity(self) -> None:
self.cosine_similarity = torch.nn.functional.cosine_similarity(self.previous_residual_subsampled, self.current_residual_subsampled, dim=-1).mean().item()
def derivative_approximation(self, sb_holder: SortblockHolder, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
activation_distance = sb_holder.activated_steps[-1] - sb_holder.activated_steps[-2]
# make tuple if not already tuple, so that works with both single and double blocks
if not isinstance(output_raw, tuple):
output_raw = (output_raw,)
if not isinstance(orig_inputs, tuple):
orig_inputs = (orig_inputs,)
for i, (output, x) in enumerate(zip(output_raw, orig_inputs)):
feature = output.clone() - x
has_previous_taylor_factor = self.previous_taylor_factors.get(i, None) is not None
# NOTE: not sure why - 2, but that's what's in the original implementation. Maybe consider changing values?
if has_previous_taylor_factor and sb_holder.step_count > (sb_holder.first_enhance - 2):
self.current_taylor_factors[i] = (
feature - self.previous_taylor_factors[i]
) / activation_distance
self.previous_taylor_factors[i] = feature
def use_taylor_formula(self, sb_holder: SortblockHolder) -> tuple[torch.Tensor, ...]:
step_distance = sb_holder.step_count - sb_holder.activated_steps[-1]
output_predicted = []
for key in self.previous_taylor_factors.keys():
previous_tf = self.previous_taylor_factors[key]
current_tf = self.current_taylor_factors[key]
predicted = taylor_formula(previous_tf, 0, step_distance)
predicted += taylor_formula(current_tf, 1, step_distance)
output_predicted.append(predicted)
return tuple(output_predicted)
def reset(self):
self.recompute = False
self.current_residual_subsampled = None
self.previous_residual_subsampled = None
self.cosine_similarity = None
self.previous_taylor_factors = {}
self.current_taylor_factors = {}
def taylor_formula(taylor_factor: torch.Tensor, i: int, step_distance: int):
return (
(1 / math.factorial(i))
* taylor_factor
* (step_distance ** i)
)
class SortblockHolder:
def __init__(self, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int,
start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False):
self.start_predict_ratio = start_predict_ratio
self.end_predict_ratio = end_predict_ratio
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.verbose = verbose
# NOTE: number represents steps
self.policy_refresh_interval = policy_refresh_interval
self.active_policy_refresh_interval = 1
self.first_enhance = 3 # NOTE: this value is 2 higher than the one actually used in code (subtracted by 2 in derivative_approximation)
# timestep values
self.start_t = 0.0
self.end_t = 0.0
self.curr_t = 0.0
# control values
self.initial_step = True
self.step_count = 0
self.activated_steps: list[int] = [0]
self.step_modulus = 0
self.do_sortblock = False
self.active_steps = 0
self.predict_ratios = []
self.predict_start_index = 0
# cache values
self.all_blocks = []
self.blocks_per_type = {}
def add_to_all_blocks(self, block):
self.all_blocks.append(block)
def add_to_blocks_per_type(self, block, block_type: str):
self.blocks_per_type.setdefault(block_type, []).append(block)
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
return (timestep <= self.start_t).item() and (timestep > self.end_t).item()
def update_should_do_sortblock(self, timestep: float) -> bool:
self.do_sortblock = (timestep[0] <= self.start_t).item() and (timestep[0] > self.end_t).item()
self.curr_t = timestep
if self.do_sortblock:
self.active_policy_refresh_interval = self.policy_refresh_interval
else:
self.active_policy_refresh_interval = 1
self.update_step_modulus()
return self.do_sortblock
def update_step_modulus(self):
self.step_modulus = int(self.step_count % self.active_policy_refresh_interval)
def should_do_sortblock(self) -> bool:
return self.do_sortblock
def reset(self):
self.initial_step = True
self.curr_t = 0.0
logging.info(f"Sortblock: resetting {len(self.all_blocks)} blocks")
for block in self.all_blocks:
clean_block(block)
self.all_blocks = []
self.blocks_per_type = {}
self.step_count = 0
self.activated_steps = [0]
self.step_modulus = 0
self.active_steps = 0
self.predict_ratios = []
self.do_sortblock = False
self.predict_start_index = 0
return self
def clone(self):
return SortblockHolder(start_predict_ratio=self.start_predict_ratio, end_predict_ratio=self.end_predict_ratio, policy_refresh_interval=self.policy_refresh_interval,
start_percent=self.start_percent, end_percent=self.end_percent, subsample_factor=self.subsample_factor,
verbose=self.verbose)
class SortblockNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="Sortblock",
display_name="Sortblock",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add Sortblock to."),
io.Float.Input("predict_ratio", min=0.0, default=0.8, max=3.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with Sortblock."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
# TODO: check for specific flavors of supported models
model = model.clone()
model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio=predict_ratio, end_predict_ratio=predict_ratio, policy_refresh_interval=policy_refresh_interval,
start_percent=start_percent, end_percent=end_percent, subsample_factor=8, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
return io.NodeOutput(model)
class SortblockScaledNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SortblockScaled",
display_name="SortblockScaled",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add Sortblock to."),
io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with Sortblock."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
# TODO: check for specific flavors of supported models
model = model.clone()
model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
return io.NodeOutput(model)
class SortblockExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SortblockNode,
SortblockScaledNode,
]
def comfy_entrypoint():
return SortblockExtension()

View File

@@ -17,55 +17,61 @@
"""
import torch
import nodes
from typing_extensions import override
import comfy.utils
import nodes
from comfy_api.latest import ComfyExtension, io
class StableCascade_EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_EmptyLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_EmptyLatentImage",
category="latent/stable_cascade",
inputs=[
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, width, height, compression, batch_size=1):
def execute(cls, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_StageC_VAEEncode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageC_VAEEncode",
category="latent/stable_cascade",
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, image, vae, compression):
def execute(cls, image, vae, compression):
width = image.shape[-2]
height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio
@@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageB_Conditioning:
class StableCascade_StageB_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "conditioning": ("CONDITIONING",),
"stage_c": ("LATENT",),
}}
RETURN_TYPES = ("CONDITIONING",)
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageB_Conditioning",
category="conditioning/stable_cascade",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("stage_c"),
],
outputs=[
io.Conditioning.Output(),
],
)
FUNCTION = "set_prior"
CATEGORY = "conditioning/stable_cascade"
def set_prior(self, conditioning, stage_c):
@classmethod
def execute(cls, conditioning, stage_c):
c = []
for t in conditioning:
d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples']
d["stable_cascade_prior"] = stage_c["samples"]
n = [t[0], d]
c.append(n)
return (c, )
return io.NodeOutput(c)
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
],
outputs=[
io.Image.Output(display_name="controlnet_input"),
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
def execute(cls, image, vae):
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
@@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
return io.NodeOutput(controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
}
class StableCascadeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
StableCascade_EmptyLatentImage,
StableCascade_StageB_Conditioning,
StableCascade_StageC_VAEEncode,
StableCascade_SuperResolutionControlnet,
]
async def comfy_entrypoint() -> StableCascadeExtension:
return StableCascadeExtension()

View File

@@ -5,52 +5,49 @@ import av
import torch
import folder_paths
import json
from typing import Optional, Literal
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
from comfy_api.latest import Input, InputImpl, Types
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
class SaveWEBM:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class SaveWEBM(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveWEBM",
category="image/video",
is_experimental=True,
inputs=[
io.Image.Input("images"),
io.String.Input("filename_prefix", default="ComfyUI"),
io.Combo.Input("codec", options=["vp9", "av1"]),
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"codec": (["vp9", "av1"],),
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "image/video"
EXPERIMENTAL = True
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
)
file = f"{filename}_{counter:05}_.webm"
container = av.open(os.path.join(full_output_folder, file), mode="w")
if prompt is not None:
container.metadata["prompt"] = json.dumps(prompt)
if cls.hidden.prompt is not None:
container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
container.metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
@@ -69,63 +66,46 @@ class SaveWEBM:
container.mux(stream.encode())
container.close()
results: list[FileLocator] = [{
"filename": file,
"subfolder": subfolder,
"type": self.type
}]
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
class SaveVideo(ComfyNodeABC):
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type: Literal["output"] = "output"
self.prefix_append = ""
class SaveVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveVideo",
display_name="Save Video",
category="image/video",
description="Saves the input images to your ComfyUI output directory.",
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
FUNCTION = "save_video"
OUTPUT_NODE = True
CATEGORY = "image/video"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
self.output_dir,
folder_paths.get_output_directory(),
width,
height
)
results: list[FileLocator] = list()
saved_metadata = None
if not args.disable_metadata:
metadata = {}
if extra_pnginfo is not None:
metadata.update(extra_pnginfo)
if prompt is not None:
metadata["prompt"] = prompt
if cls.hidden.extra_pnginfo is not None:
metadata.update(cls.hidden.extra_pnginfo)
if cls.hidden.prompt is not None:
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=format,
@@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC):
metadata=saved_metadata
)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return { "ui": { "images": results, "animated": (True,) } }
class CreateVideo(ComfyNodeABC):
class CreateVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
},
"optional": {
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
}
}
def define_schema(cls):
return io.Schema(
node_id="CreateVideo",
display_name="Create Video",
category="image/video",
description="Create a video from images.",
inputs=[
io.Image.Input("images", tooltip="The images to create a video from."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
],
outputs=[
io.Video.Output(),
],
)
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "create_video"
CATEGORY = "image/video"
DESCRIPTION = "Create a video from images."
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
return (InputImpl.VideoFromComponents(
Types.VideoComponents(
images=images,
audio=audio,
frame_rate=Fraction(fps),
)
),)
class GetVideoComponents(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
}
}
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
RETURN_NAMES = ("images", "audio", "fps")
FUNCTION = "get_components"
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
CATEGORY = "image/video"
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
class GetVideoComponents(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GetVideoComponents",
display_name="Get Video Components",
category="image/video",
description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[
io.Video.Input("video", tooltip="The video to extract components from."),
],
outputs=[
io.Image.Output(display_name="images"),
io.Audio.Output(display_name="audio"),
io.Float.Output(display_name="fps"),
],
)
def get_components(self, video: Input.Video):
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
components = video.get_components()
return (components.images, components.audio, float(components.frame_rate))
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(ComfyNodeABC):
class LoadVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["video"])
return {"required":
{"file": (sorted(files), {"video_upload": True})},
}
CATEGORY = "image/video"
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "load_video"
def load_video(self, file):
video_path = folder_paths.get_annotated_filepath(file)
return (InputImpl.VideoFromFile(video_path),)
return io.Schema(
node_id="LoadVideo",
display_name="Load Video",
category="image/video",
inputs=[
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def IS_CHANGED(cls, file):
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):
video_path = folder_paths.get_annotated_filepath(file)
mod_time = os.path.getmtime(video_path)
# Instead of hashing the file, we can just use the modification time to avoid
@@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC):
return mod_time
@classmethod
def VALIDATE_INPUTS(cls, file):
def validate_inputs(s, file):
if not folder_paths.exists_annotated_filepath(file):
return "Invalid video file: {}".format(file)
return True
NODE_CLASS_MAPPINGS = {
"SaveWEBM": SaveWEBM,
"SaveVideo": SaveVideo,
"CreateVideo": CreateVideo,
"GetVideoComponents": GetVideoComponents,
"LoadVideo": LoadVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveVideo": "Save Video",
"CreateVideo": "Create Video",
"GetVideoComponents": "Get Video Components",
"LoadVideo": "Load Video",
}
class VideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SaveWEBM,
SaveVideo,
CreateVideo,
GetVideoComponents,
LoadVideo,
]
async def comfy_entrypoint() -> VideoExtension:
return VideoExtension()

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.56"
__version__ = "0.3.57"

View File

@@ -113,7 +113,6 @@ import gc
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__":
if args.default_device is not None:

View File

@@ -2325,8 +2325,6 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_sortblock.py",
"nodes_easysortblock.py",
]
import_failed = []
@@ -2346,6 +2344,7 @@ async def init_builtin_api_nodes():
"nodes_veo2.py",
"nodes_kling.py",
"nodes_bfl.py",
"nodes_bytedance.py",
"nodes_luma.py",
"nodes_recraft.py",
"nodes_pixverse.py",

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.56"
version = "0.3.57"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.25.11
comfyui-workflow-templates==0.1.70
comfyui-workflow-templates==0.1.75
comfyui-embedded-docs==0.2.6
torch
torchsde

View File

@@ -3,11 +3,7 @@ from urllib import request
#This is the ComfyUI api prompt format.
#If you want it for a specific workflow you can "enable dev mode options"
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
#a button on the UI to save workflows in api format.
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
#If you want it for a specific workflow you can "File -> Export (API)" in the interface.
#this is the one for the default workflow
prompt_text = """

View File

@@ -729,7 +729,34 @@ class PromptServer():
@routes.post("/interrupt")
async def post_interrupt(request):
nodes.interrupt_processing()
try:
json_data = await request.json()
except json.JSONDecodeError:
json_data = {}
# Check if a specific prompt_id was provided for targeted interruption
prompt_id = json_data.get('prompt_id')
if prompt_id:
currently_running, _ = self.prompt_queue.get_current_queue()
# Check if the prompt_id matches any currently running prompt
should_interrupt = False
for item in currently_running:
# item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute)
if item[1] == prompt_id:
logging.info(f"Interrupting prompt {prompt_id}")
should_interrupt = True
break
if should_interrupt:
nodes.interrupt_processing()
else:
logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt")
else:
# No prompt_id provided, do a global interrupt
logging.info("Global interrupt (no prompt_id specified)")
nodes.interrupt_processing()
return web.Response(status=200)
@routes.post("/free")