Compare commits

...

7 Commits

Author SHA1 Message Date
bymyself
090c1dd3e6 fix: move essentials_category to correct replacement nodes
Move essentials_category from deprecated/incorrect nodes to their replacements:
- ImageBatch → BatchImagesNode (ImageBatch is deprecated)
- Blur → removed (should use subgraph blueprint)
- GetVideoComponents → Video Slice

Amp-Thread-ID: https://ampcode.com/threads/T-019c8340-4da2-723b-a09f-83895c5bbda5
2026-02-26 00:40:57 -08:00
comfyanonymous
8a4d85c708 Cleanups to the last PR. (#12646) 2026-02-26 01:30:31 -05:00
Tavi Halperin
a4522017c5 feat: per-guide attention strength control in self-attention (#12518)
Implements per-guide attention attenuation via log-space additive bias
in self-attention. Each guide reference tracks its own strength and
optional spatial mask in conditioning metadata (guide_attention_entries).
2026-02-26 01:25:23 -05:00
Jukka Seppänen
907e5dcbbf initial FlowRVS support (#12637) 2026-02-25 23:38:46 -05:00
comfyanonymous
7253531670 Fix ltxav te mem estimation. (#12643) 2026-02-25 23:13:47 -05:00
comfyanonymous
e14b04478c Fix LTXAV text enc min length. (#12640)
Should have been 1024 instead of 512
2026-02-25 22:36:02 -05:00
Christian Byrne
eb8737d675 Update requirements.txt (#12642) 2026-02-25 18:30:48 -08:00
16 changed files with 385 additions and 24 deletions

View File

@@ -4,6 +4,25 @@ import comfy.utils
import logging
def is_equal(x, y):
if torch.is_tensor(x) and torch.is_tensor(y):
return torch.equal(x, y)
elif isinstance(x, dict) and isinstance(y, dict):
if x.keys() != y.keys():
return False
return all(is_equal(x[k], y[k]) for k in x)
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
if type(x) is not type(y) or len(x) != len(y):
return False
return all(is_equal(a, b) for a, b in zip(x, y))
else:
try:
return x == y
except Exception:
logging.warning("comparison issue with COND")
return False
class CONDRegular:
def __init__(self, cond):
self.cond = cond
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
if not is_equal(self.cond, other.cond):
return False
return True

View File

@@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
@@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel):
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
):
vx = x[0]
ax = x[1]
@@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
)
return out
@@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel):
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
},
{"original_block": block_wrap},
)
@@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
)
return [vx, ax]

View File

@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import logging
import math
from typing import Dict, Optional, Tuple
@@ -14,6 +15,8 @@ import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
logger = logging.getLogger(__name__)
def _log_base(x, base):
return np.log(x) / np.log(base)
@@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
@@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
"""Process input data. Must be implemented by subclasses."""
pass
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Base implementation returns None (no attenuation). Subclasses that
support guide-based attention control should override this.
"""
return None
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses."""
pass
@@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Build self-attention mask for per-guide attenuation
self_attention_mask = self._build_guide_self_attention_mask(
x, transformer_options, merged_args
)
# Process transformer blocks
x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
x, context, attention_mask, timestep, pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
**merged_args,
)
# Process output
@@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
pixel_coords = pixel_coords[:, :, grid_mask, ...]
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
# Compute per-guide surviving token counts from guide_attention_entries.
# Each entry tracks one guide reference; they are appended in order and
# their pre_filter_counts partition the kf_grid_mask.
guide_entries = kwargs.get("guide_attention_entries", None)
if guide_entries:
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
if total_pfc != len(kf_grid_mask):
raise ValueError(
f"guide pre_filter_counts ({total_pfc}) != "
f"keyframe grid mask length ({len(kf_grid_mask)})"
)
resolved_entries = []
offset = 0
for entry in guide_entries:
pfc = entry["pre_filter_count"]
entry_mask = kf_grid_mask[offset:offset + pfc]
surviving = int(entry_mask.sum().item())
resolved_entries.append({
**entry,
"surviving_count": surviving,
})
offset += pfc
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
x = self.patchify_proj(x)
return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Reads resolved_guide_entries from merged_args (computed in _process_input)
to build a log-space additive bias mask that attenuates noisy ↔ guide
attention for each guide reference independently.
Returns None if no attenuation is needed (all strengths == 1.0 and no
spatial masks, or no guide tokens).
"""
if isinstance(x, list):
# AV model: x = [vx, ax]; use vx for token count and device
total_tokens = x[0].shape[1]
device = x[0].device
dtype = x[0].dtype
else:
total_tokens = x.shape[1]
device = x.device
dtype = x.dtype
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
if num_guide_tokens == 0:
return None
resolved_entries = merged_args.get("resolved_guide_entries", None)
if not resolved_entries:
return None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_attenuation:
return None
# Build per-guide-token weights for all tracked guide tokens.
# Guides are appended in order at the end of the sequence.
guide_start = total_tokens - num_guide_tokens
all_weights = []
total_tracked = 0
for entry in resolved_entries:
surviving = entry["surviving_count"]
if surviving == 0:
continue
strength = entry["strength"]
pixel_mask = entry.get("pixel_mask")
latent_shape = entry.get("latent_shape")
if pixel_mask is not None and latent_shape is not None:
f_lat, h_lat, w_lat = latent_shape
per_token = self._downsample_mask_to_latent(
pixel_mask.to(device=device, dtype=dtype),
f_lat, h_lat, w_lat,
)
# per_token shape: (B, f_lat*h_lat*w_lat).
# Collapse batch dim — the mask is assumed identical across the
# batch; validate and take the first element to get (1, tokens).
if per_token.shape[0] > 1:
ref = per_token[0]
for bi in range(1, per_token.shape[0]):
if not torch.equal(ref, per_token[bi]):
logger.warning(
"pixel_mask differs across batch elements; "
"using first element only."
)
break
per_token = per_token[:1]
# `surviving` is the post-grid_mask token count.
# Clamp to surviving to handle any mismatch safely.
n_weights = min(per_token.shape[1], surviving)
weights = per_token[:, :n_weights] * strength # (1, n_weights)
else:
weights = torch.full(
(1, surviving), strength, device=device, dtype=dtype
)
all_weights.append(weights)
total_tracked += weights.shape[1]
if not all_weights:
return None
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
return None
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
"""Downsample a pixel-space mask to per-token latent weights.
Args:
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
f_lat: Number of latent frames (pre-dilation original count).
h_lat: Latent height (pre-dilation original height).
w_lat: Latent width (pre-dilation original width).
Returns:
(B, F_lat * H_lat * W_lat) flattened per-token weights.
"""
b = mask.shape[0]
f_pix = mask.shape[2]
# Spatial downsampling: area interpolation per frame
spatial_down = torch.nn.functional.interpolate(
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
size=(h_lat, w_lat),
mode="area",
)
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
# Temporal downsampling: first pixel frame maps to first latent frame,
# remaining pixel frames are averaged in groups for causal temporal structure.
first_frame = spatial_down[:, :, :1, :, :]
if f_pix > 1 and f_lat > 1:
remaining_pix = f_pix - 1
remaining_lat = f_lat - 1
t = remaining_pix // remaining_lat
if t < 1:
# Fewer pixel frames than latent frames — upsample by repeating
# the available pixel frames via nearest interpolation.
rest_flat = rearrange(
spatial_down[:, :, 1:, :, :],
"b 1 f h w -> (b h w) 1 f",
)
rest_up = torch.nn.functional.interpolate(
rest_flat, size=remaining_lat, mode="nearest",
)
rest = rearrange(
rest_up, "(b h w) 1 f -> b 1 f h w",
b=b, h=h_lat, w=w_lat,
)
else:
# Trim trailing pixel frames that don't fill a complete group
usable = remaining_lat * t
rest = rearrange(
spatial_down[:, :, 1:1 + usable, :, :],
"b 1 (f t) h w -> b 1 f t h w",
t=t,
)
rest = rest.mean(dim=3)
latent_mask = torch.cat([first_frame, rest], dim=2)
elif f_lat > 1:
# Single pixel frame but multiple latent frames — repeat the
# single frame across all latent frames.
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
else:
latent_mask = first_frame
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
timestep=timestep,
pe=pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
)
return x

View File

@@ -459,6 +459,7 @@ class WanVAE(nn.Module):
attn_scales=[],
temperal_downsample=[True, True, False],
image_channels=3,
conv_out_channels=3,
dropout=0.0):
super().__init__()
self.dim = dim
@@ -474,7 +475,7 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):

View File

@@ -76,6 +76,7 @@ class ModelType(Enum):
FLUX = 8
IMG_TO_IMG = 9
FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type):
@@ -108,6 +109,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSampling(s, c):
pass
@@ -971,6 +974,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@@ -1023,6 +1030,10 @@ class LTXAV(BaseModel):
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1466,6 +1477,12 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class WAN21_FlowRVS(WAN21):
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
model_config.unet_config["model_type"] = "t2v"
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@@ -509,6 +509,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if ref_conv_weight is not None:
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

View File

@@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise):
return noise
class IMG_TO_IMG_FLOW(CONST):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return latent_image
def inverse_noise_scaling(self, sigma, latent):
return 1.0 - latent
class COSMOS_RFLOW:
def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1))

View File

@@ -694,8 +694,9 @@ class VAE:
self.latent_dim = 3
self.latent_channels = 16
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.conv_out_channels = sd["decoder.head.2.weight"].shape[0]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)

View File

@@ -1256,6 +1256,16 @@ class WAN22_T2V(WAN21_T2V):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class WAN21_FlowRVS(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "flow_rvs",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@@ -1667,6 +1677,6 @@ class ACEStep15(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]

View File

@@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
import torch
import comfy.utils
import math
import itertools
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -72,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
@@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module):
constant /= 2.0
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens = max(num_tokens, 64)
m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
num_tokens = max(num_tokens, 642)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

View File

@@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
"""Append a guide_attention_entry to both positive and negative conditioning.
Each entry tracks one guide reference for per-reference attention control.
Entries are derived independently from each conditioning to avoid cross-contamination.
"""
new_entry = {
"pre_filter_count": pre_filter_count,
"strength": strength,
"pixel_mask": None,
"latent_shape": latent_shape,
}
results = []
for cond in (positive, negative):
# Read existing entries from this specific conditioning
existing = []
for t in cond:
found = t[1].get("guide_attention_entries", None)
if found is not None:
existing = found
break
# Shallow copy and append (no deepcopy needed — entries contain
# only scalars and None for pixel_mask at this call site).
entries = [*existing, new_entry]
results.append(node_helpers.conditioning_set_values(
cond, {"guide_attention_entries": entries}
))
return results[0], results[1]
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
@@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
scale_factors,
)
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
@@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
positive = node_helpers.conditioning_set_values(positive, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
negative = node_helpers.conditioning_set_values(negative, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

View File

@@ -52,7 +52,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],),
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
}}
@@ -76,6 +76,8 @@ class ModelSamplingDiscrete:
sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG
elif sampling == "img_to_img_flow":
sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass

View File

@@ -79,7 +79,6 @@ class Blur(io.ComfyNode):
node_id="ImageBlur",
display_name="Image Blur",
category="image/postprocessing",
essentials_category="Image Tools",
inputs=[
io.Image.Input("image"),
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
@@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode):
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
essentials_category="Image Tools",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[
io.Autogrow.Input("images", template=autogrow_template)

View File

@@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode):
search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components",
category="image/video",
essentials_category="Video Tools",
description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[
io.Video.Input("video", tooltip="The video to extract components from."),
@@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode):
"start time",
],
category="image/video",
essentials_category="Video Tools",
inputs=[
io.Video.Input("video"),
io.Float.Input(

View File

@@ -1925,7 +1925,6 @@ class ImageInvert:
class ImageBatch:
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
ESSENTIALS_CATEGORY = "Image Tools"
@classmethod
def INPUT_TYPES(s):

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.39.16
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.3
comfyui-embedded-docs==0.4.3
torch