mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 14:19:57 +00:00
Compare commits
28 Commits
claude/sla
...
fix-essent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
090c1dd3e6 | ||
|
|
8a4d85c708 | ||
|
|
a4522017c5 | ||
|
|
907e5dcbbf | ||
|
|
7253531670 | ||
|
|
e14b04478c | ||
|
|
eb8737d675 | ||
|
|
0467f690a8 | ||
|
|
4f5b7dbf1f | ||
|
|
3ebe1ac22e | ||
|
|
befa83d434 | ||
|
|
33f83d53ae | ||
|
|
b874bd2b8c | ||
|
|
0aa02453bb | ||
|
|
599f9c5010 | ||
|
|
11fefa58e9 | ||
|
|
d8090013b8 | ||
|
|
048dd2f321 | ||
|
|
84aba95e03 | ||
|
|
9b1c63eb69 | ||
|
|
7a7debcaf1 | ||
|
|
dba2766e53 | ||
|
|
caa43d2395 | ||
|
|
07ca6852e8 | ||
|
|
f266b8d352 | ||
|
|
b6cb30bab5 | ||
|
|
ee72752162 | ||
|
|
7591d781a7 |
@@ -1,6 +1,7 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||
|
||||
reviews:
|
||||
profile: "chill"
|
||||
@@ -35,6 +36,14 @@ reviews:
|
||||
- "!**/*.bat"
|
||||
|
||||
path_instructions:
|
||||
- path: "**"
|
||||
instructions: |
|
||||
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||
treat it as unchanged. Contributors should not feel obligated to address
|
||||
pre-existing issues outside the scope of their contribution.
|
||||
- path: "comfy/**"
|
||||
instructions: |
|
||||
Core ML/diffusion engine. Focus on:
|
||||
@@ -74,7 +83,11 @@ reviews:
|
||||
auto_review:
|
||||
enabled: true
|
||||
auto_incremental_review: true
|
||||
drafts: true
|
||||
drafts: false
|
||||
ignore_title_keywords:
|
||||
- "WIP"
|
||||
- "DO NOT REVIEW"
|
||||
- "DO NOT MERGE"
|
||||
|
||||
finishing_touches:
|
||||
docstrings:
|
||||
@@ -84,7 +97,7 @@ reviews:
|
||||
|
||||
tools:
|
||||
ruff:
|
||||
enabled: true
|
||||
enabled: false
|
||||
pylint:
|
||||
enabled: false
|
||||
flake8:
|
||||
|
||||
@@ -46,6 +46,8 @@ class NodeReplaceManager:
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
if "class_type" not in node_struct or "inputs" not in node_struct:
|
||||
continue
|
||||
class_type = node_struct["class_type"]
|
||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||
|
||||
@@ -4,6 +4,25 @@ import comfy.utils
|
||||
import logging
|
||||
|
||||
|
||||
def is_equal(x, y):
|
||||
if torch.is_tensor(x) and torch.is_tensor(y):
|
||||
return torch.equal(x, y)
|
||||
elif isinstance(x, dict) and isinstance(y, dict):
|
||||
if x.keys() != y.keys():
|
||||
return False
|
||||
return all(is_equal(x[k], y[k]) for k in x)
|
||||
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||
if type(x) is not type(y) or len(x) != len(y):
|
||||
return False
|
||||
return all(is_equal(a, b) for a, b in zip(x, y))
|
||||
else:
|
||||
try:
|
||||
return x == y
|
||||
except Exception:
|
||||
logging.warning("comparison issue with COND")
|
||||
return False
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
@@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
if not is_equal(self.cond, other.cond):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
|
||||
LTXVModel,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
class CompressedTimestep:
|
||||
@@ -217,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
def forward(
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@@ -233,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
@@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def preprocess_text_embeds(self, context):
|
||||
if context.shape[-1] == self.caption_channels * 2:
|
||||
return context
|
||||
out_vid = self.video_embeddings_connector(context)[0]
|
||||
out_audio = self.audio_embeddings_connector(context)[0]
|
||||
return torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
"""Initialize transformer blocks for LTXAV."""
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
@@ -702,7 +726,7 @@ class LTXAVModel(LTXVModel):
|
||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
|
||||
):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
@@ -746,6 +770,7 @@ class LTXAVModel(LTXVModel):
|
||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
self_attention_mask=args.get("self_attention_mask"),
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -766,6 +791,7 @@ class LTXAVModel(LTXVModel):
|
||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
"self_attention_mask": self_attention_mask,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
@@ -787,6 +813,7 @@ class LTXAVModel(LTXVModel):
|
||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
@@ -157,11 +157,9 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = nn.Parameter(
|
||||
torch.rand(
|
||||
torch.empty(
|
||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||
)
|
||||
* 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
def get_fractional_positions(self, indices_grid):
|
||||
@@ -234,7 +232,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
return indices
|
||||
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
||||
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||
dim = self.inner_dim
|
||||
n_elem = 2 # 2 because of cos and sin
|
||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||
@@ -247,7 +245,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
)
|
||||
else:
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
||||
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -288,7 +286,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
||||
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||
|
||||
# 2. Blocks
|
||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
@@ -14,6 +15,8 @@ import comfy.ldm.common_dit
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
@@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
|
||||
@@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
"""Process input data. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||
"""Build self-attention mask for per-guide attention attenuation.
|
||||
|
||||
Base implementation returns None (no attenuation). Subclasses that
|
||||
support guide-based attention control should override this.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||
|
||||
# Build self-attention mask for per-guide attenuation
|
||||
self_attention_mask = self._build_guide_self_attention_mask(
|
||||
x, transformer_options, merged_args
|
||||
)
|
||||
|
||||
# Process transformer blocks
|
||||
x = self._process_transformer_blocks(
|
||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
||||
x, context, attention_mask, timestep, pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
**merged_args,
|
||||
)
|
||||
|
||||
# Process output
|
||||
@@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
|
||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||
|
||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||
|
||||
# Compute per-guide surviving token counts from guide_attention_entries.
|
||||
# Each entry tracks one guide reference; they are appended in order and
|
||||
# their pre_filter_counts partition the kf_grid_mask.
|
||||
guide_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_entries:
|
||||
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
|
||||
if total_pfc != len(kf_grid_mask):
|
||||
raise ValueError(
|
||||
f"guide pre_filter_counts ({total_pfc}) != "
|
||||
f"keyframe grid mask length ({len(kf_grid_mask)})"
|
||||
)
|
||||
resolved_entries = []
|
||||
offset = 0
|
||||
for entry in guide_entries:
|
||||
pfc = entry["pre_filter_count"]
|
||||
entry_mask = kf_grid_mask[offset:offset + pfc]
|
||||
surviving = int(entry_mask.sum().item())
|
||||
resolved_entries.append({
|
||||
**entry,
|
||||
"surviving_count": surviving,
|
||||
})
|
||||
offset += pfc
|
||||
additional_args["resolved_guide_entries"] = resolved_entries
|
||||
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
# Total surviving guide tokens (all guides)
|
||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
return x, pixel_coords, additional_args
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
||||
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||
"""Build self-attention mask for per-guide attention attenuation.
|
||||
|
||||
Reads resolved_guide_entries from merged_args (computed in _process_input)
|
||||
to build a log-space additive bias mask that attenuates noisy ↔ guide
|
||||
attention for each guide reference independently.
|
||||
|
||||
Returns None if no attenuation is needed (all strengths == 1.0 and no
|
||||
spatial masks, or no guide tokens).
|
||||
"""
|
||||
if isinstance(x, list):
|
||||
# AV model: x = [vx, ax]; use vx for token count and device
|
||||
total_tokens = x[0].shape[1]
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
else:
|
||||
total_tokens = x.shape[1]
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
|
||||
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
|
||||
if num_guide_tokens == 0:
|
||||
return None
|
||||
|
||||
resolved_entries = merged_args.get("resolved_guide_entries", None)
|
||||
if not resolved_entries:
|
||||
return None
|
||||
|
||||
# Check if any attenuation is actually needed
|
||||
needs_attenuation = any(
|
||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||
for e in resolved_entries
|
||||
)
|
||||
if not needs_attenuation:
|
||||
return None
|
||||
|
||||
# Build per-guide-token weights for all tracked guide tokens.
|
||||
# Guides are appended in order at the end of the sequence.
|
||||
guide_start = total_tokens - num_guide_tokens
|
||||
all_weights = []
|
||||
total_tracked = 0
|
||||
|
||||
for entry in resolved_entries:
|
||||
surviving = entry["surviving_count"]
|
||||
if surviving == 0:
|
||||
continue
|
||||
|
||||
strength = entry["strength"]
|
||||
pixel_mask = entry.get("pixel_mask")
|
||||
latent_shape = entry.get("latent_shape")
|
||||
|
||||
if pixel_mask is not None and latent_shape is not None:
|
||||
f_lat, h_lat, w_lat = latent_shape
|
||||
per_token = self._downsample_mask_to_latent(
|
||||
pixel_mask.to(device=device, dtype=dtype),
|
||||
f_lat, h_lat, w_lat,
|
||||
)
|
||||
# per_token shape: (B, f_lat*h_lat*w_lat).
|
||||
# Collapse batch dim — the mask is assumed identical across the
|
||||
# batch; validate and take the first element to get (1, tokens).
|
||||
if per_token.shape[0] > 1:
|
||||
ref = per_token[0]
|
||||
for bi in range(1, per_token.shape[0]):
|
||||
if not torch.equal(ref, per_token[bi]):
|
||||
logger.warning(
|
||||
"pixel_mask differs across batch elements; "
|
||||
"using first element only."
|
||||
)
|
||||
break
|
||||
per_token = per_token[:1]
|
||||
# `surviving` is the post-grid_mask token count.
|
||||
# Clamp to surviving to handle any mismatch safely.
|
||||
n_weights = min(per_token.shape[1], surviving)
|
||||
weights = per_token[:, :n_weights] * strength # (1, n_weights)
|
||||
else:
|
||||
weights = torch.full(
|
||||
(1, surviving), strength, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
all_weights.append(weights)
|
||||
total_tracked += weights.shape[1]
|
||||
|
||||
if not all_weights:
|
||||
return None
|
||||
|
||||
# Concatenate per-token weights for all tracked guides
|
||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||
|
||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||
if (tracked_weights >= 1.0).all():
|
||||
return None
|
||||
|
||||
# Build the mask: guide tokens are at the end of the sequence.
|
||||
# Tracked guides come first (in order), untracked follow.
|
||||
return self._build_self_attention_mask(
|
||||
total_tokens, num_guide_tokens, total_tracked,
|
||||
tracked_weights, guide_start, device, dtype,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||
"""Downsample a pixel-space mask to per-token latent weights.
|
||||
|
||||
Args:
|
||||
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
|
||||
f_lat: Number of latent frames (pre-dilation original count).
|
||||
h_lat: Latent height (pre-dilation original height).
|
||||
w_lat: Latent width (pre-dilation original width).
|
||||
|
||||
Returns:
|
||||
(B, F_lat * H_lat * W_lat) flattened per-token weights.
|
||||
"""
|
||||
b = mask.shape[0]
|
||||
f_pix = mask.shape[2]
|
||||
|
||||
# Spatial downsampling: area interpolation per frame
|
||||
spatial_down = torch.nn.functional.interpolate(
|
||||
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
|
||||
size=(h_lat, w_lat),
|
||||
mode="area",
|
||||
)
|
||||
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
|
||||
|
||||
# Temporal downsampling: first pixel frame maps to first latent frame,
|
||||
# remaining pixel frames are averaged in groups for causal temporal structure.
|
||||
first_frame = spatial_down[:, :, :1, :, :]
|
||||
if f_pix > 1 and f_lat > 1:
|
||||
remaining_pix = f_pix - 1
|
||||
remaining_lat = f_lat - 1
|
||||
t = remaining_pix // remaining_lat
|
||||
if t < 1:
|
||||
# Fewer pixel frames than latent frames — upsample by repeating
|
||||
# the available pixel frames via nearest interpolation.
|
||||
rest_flat = rearrange(
|
||||
spatial_down[:, :, 1:, :, :],
|
||||
"b 1 f h w -> (b h w) 1 f",
|
||||
)
|
||||
rest_up = torch.nn.functional.interpolate(
|
||||
rest_flat, size=remaining_lat, mode="nearest",
|
||||
)
|
||||
rest = rearrange(
|
||||
rest_up, "(b h w) 1 f -> b 1 f h w",
|
||||
b=b, h=h_lat, w=w_lat,
|
||||
)
|
||||
else:
|
||||
# Trim trailing pixel frames that don't fill a complete group
|
||||
usable = remaining_lat * t
|
||||
rest = rearrange(
|
||||
spatial_down[:, :, 1:1 + usable, :, :],
|
||||
"b 1 (f t) h w -> b 1 f t h w",
|
||||
t=t,
|
||||
)
|
||||
rest = rest.mean(dim=3)
|
||||
latent_mask = torch.cat([first_frame, rest], dim=2)
|
||||
elif f_lat > 1:
|
||||
# Single pixel frame but multiple latent frames — repeat the
|
||||
# single frame across all latent frames.
|
||||
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
|
||||
else:
|
||||
latent_mask = first_frame
|
||||
|
||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||
|
||||
@staticmethod
|
||||
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||
tracked_weights, guide_start, device, dtype):
|
||||
"""Build a log-space additive self-attention bias mask.
|
||||
|
||||
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||
|
||||
Args:
|
||||
total_tokens: Total sequence length.
|
||||
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||
guide_start: Index where guide tokens begin in the sequence.
|
||||
device: Target device.
|
||||
dtype: Target dtype.
|
||||
|
||||
Returns:
|
||||
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||
"""
|
||||
finfo = torch.finfo(dtype)
|
||||
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||
tracked_end = guide_start + tracked_count
|
||||
|
||||
# Convert weights to log-space bias
|
||||
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||
log_w = torch.full_like(w, finfo.min)
|
||||
positive_mask = w > 0
|
||||
if positive_mask.any():
|
||||
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||
|
||||
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
@@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
|
||||
timestep=timestep,
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@@ -459,6 +459,7 @@ class WanVAE(nn.Module):
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
conv_out_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -474,7 +475,7 @@ class WanVAE(nn.Module):
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
|
||||
@@ -76,6 +76,7 @@ class ModelType(Enum):
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
FLOW_COSMOS = 10
|
||||
IMG_TO_IMG_FLOW = 11
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@@ -108,6 +109,8 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -971,6 +974,10 @@ class LTXV(BaseModel):
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
@@ -988,10 +995,14 @@ class LTXAV(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
@@ -1019,6 +1030,10 @@ class LTXAV(BaseModel):
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
@@ -1462,6 +1477,12 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||
model_config.unet_config["model_type"] = "t2v"
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
@@ -509,6 +509,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if ref_conv_weight is not None:
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
|
||||
if metadata is not None and "config" in metadata:
|
||||
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
|
||||
@@ -271,6 +271,7 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@@ -307,8 +308,15 @@ class ModelPatcher:
|
||||
def get_free_memory(self, device):
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
def clone(self, disable_dynamic=False):
|
||||
class_ = self.__class__
|
||||
model = self.model
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
model = temp_model_patcher.model
|
||||
|
||||
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@@ -362,6 +370,8 @@ class ModelPatcher:
|
||||
n.is_clip = self.is_clip
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.cached_patcher_init = self.cached_patcher_init
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
@@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
class IMG_TO_IMG_FLOW(CONST):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
return model_output
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return latent_image
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return 1.0 - latent
|
||||
|
||||
class COSMOS_RFLOW:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
|
||||
10
comfy/ops.py
10
comfy/ops.py
@@ -19,7 +19,7 @@
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import json
|
||||
import comfy.memory_management
|
||||
@@ -296,7 +296,7 @@ class disable_weight_init:
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
return
|
||||
|
||||
@@ -317,7 +317,7 @@ class disable_weight_init:
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
@@ -827,6 +827,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight'):
|
||||
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
||||
return sd
|
||||
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
|
||||
32
comfy/sd.py
32
comfy/sd.py
@@ -694,8 +694,9 @@ class VAE:
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||
self.conv_out_channels = sd["decoder.head.2.weight"].shape[0]
|
||||
self.pad_channel_value = 1.0
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
@@ -1530,14 +1531,24 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if output_model:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
return out
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||
embedding_directory=embedding_directory,
|
||||
model_options=model_options,
|
||||
te_model_options=te_model_options,
|
||||
disable_dynamic=disable_dynamic)
|
||||
return model
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
@@ -1586,7 +1597,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||
|
||||
if output_vae:
|
||||
@@ -1637,7 +1649,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
|
||||
"""
|
||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||
|
||||
@@ -1721,7 +1733,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
if not model_management.is_device_cpu(offload_device):
|
||||
model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
|
||||
@@ -1730,12 +1743,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
logging.info("left over keys in diffusion model: {}".format(left_over))
|
||||
return model_patcher
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||
return model
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
|
||||
@@ -1256,6 +1256,16 @@ class WAN22_T2V(WAN21_T2V):
|
||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_FlowRVS(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "flow_rvs",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@@ -1667,6 +1677,6 @@ class ACEStep15(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -3,10 +3,10 @@ import os
|
||||
from transformers import T5TokenizerFast
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.genmo
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import torch
|
||||
import comfy.utils
|
||||
import math
|
||||
import itertools
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@@ -73,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
@@ -102,6 +102,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
self.dtypes.add(dtype)
|
||||
self.compat_mode = False
|
||||
|
||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||
self.dtypes.add(dtype_llama)
|
||||
@@ -109,6 +110,11 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
operations = self.gemma3_12b.operations # TODO
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def enable_compat_mode(self): # TODO: remove
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
operations = self.gemma3_12b.operations
|
||||
dtype = self.text_embedding_projection.weight.dtype
|
||||
device = self.text_embedding_projection.weight.device
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
double_precision_rope=True,
|
||||
@@ -124,6 +130,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.compat_mode = True
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
@@ -146,9 +153,11 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
out = out.float()
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
if self.compat_mode:
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
return out.to(out_device), pooled
|
||||
|
||||
@@ -159,20 +168,30 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma3_12b.load_sd(sd)
|
||||
else:
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
missing_all = []
|
||||
unexpected_all = []
|
||||
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
|
||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||
if component_sd:
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
missing_all.extend([f"{prefix}{k}" for k in missing])
|
||||
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
||||
|
||||
if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd: # TODO: remove
|
||||
ww = sd.get("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias", None)
|
||||
if ww is not None:
|
||||
if ww.shape[0] == 3840:
|
||||
self.enable_compat_mode()
|
||||
sdv = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.video_embeddings_connector.": ""}, filter_keys=True)
|
||||
self.video_embeddings_connector.load_state_dict(sdv, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
sda = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.audio_embeddings_connector.": ""}, filter_keys=True)
|
||||
self.audio_embeddings_connector.load_state_dict(sda, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
return (missing_all, unexpected_all)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
@@ -181,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
constant /= 2.0
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
num_tokens = max(num_tokens, 64)
|
||||
m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
|
||||
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
|
||||
num_tokens = max(num_tokens, 642)
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
|
||||
@@ -29,7 +29,7 @@ import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm.auto import trange
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
import mmap
|
||||
@@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
metadata = None
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
if enables_dynamic_vram():
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
sd, metadata = load_safetensors(ckpt)
|
||||
if not return_metadata:
|
||||
metadata = None
|
||||
|
||||
@@ -27,6 +27,7 @@ class Seedream4TaskCreationRequest(BaseModel):
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(False)
|
||||
output_format: str | None = None
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
@@ -106,6 +107,7 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("3072x3072 (1:1)", 3072, 3072),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
@@ -134,6 +134,13 @@ class ImageToVideoWithAudioRequest(BaseModel):
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class KlingAvatarRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
sound_file: str = Field(...)
|
||||
prompt: str | None = Field(None)
|
||||
mode: str = Field(...)
|
||||
|
||||
|
||||
class MotionControlRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
|
||||
@@ -37,6 +37,12 @@ from comfy_api_nodes.util import (
|
||||
|
||||
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
|
||||
|
||||
SEEDREAM_MODELS = {
|
||||
"seedream 5.0 lite": "seedream-5-0-260128",
|
||||
"seedream-4-5-251128": "seedream-4-5-251128",
|
||||
"seedream-4-0-250828": "seedream-4-0-250828",
|
||||
}
|
||||
|
||||
# Long-running tasks endpoints(e.g., video)
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
@@ -180,14 +186,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNode",
|
||||
display_name="ByteDance Seedream 4.5",
|
||||
display_name="ByteDance Seedream 5.0",
|
||||
category="api node/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedream-4-5-251128", "seedream-4-0-250828"],
|
||||
tooltip="Model name",
|
||||
options=list(SEEDREAM_MODELS.keys()),
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -198,7 +203,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Input image(s) for image-to-image generation. "
|
||||
"List of 1-10 images for single or multi-reference generation.",
|
||||
"Reference image(s) for single or multi-reference generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@@ -210,8 +215,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
"width",
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=8,
|
||||
max=6240,
|
||||
step=2,
|
||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@@ -219,8 +224,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
"height",
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=8,
|
||||
max=4992,
|
||||
step=2,
|
||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@@ -283,7 +288,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03;
|
||||
$price := $contains(widgets.model, "5.0 lite") ? 0.035 :
|
||||
$contains(widgets.model, "4-5") ? 0.04 : 0.03;
|
||||
{
|
||||
"type":"usd",
|
||||
"usd": $price,
|
||||
@@ -309,6 +315,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
watermark: bool = False,
|
||||
fail_on_partial: bool = True,
|
||||
) -> IO.NodeOutput:
|
||||
model = SEEDREAM_MODELS[model]
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
w = h = None
|
||||
for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4:
|
||||
@@ -318,15 +325,12 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
|
||||
if w is None or h is None:
|
||||
w, h = width, height
|
||||
if not (1024 <= w <= 4096) or not (1024 <= h <= 4096):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
|
||||
out_num_pixels = w * h
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
||||
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
||||
f"Minimum image resolution for the selected model is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
@@ -334,9 +338,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
max_pixels = 10_404_496 if "seedream-5-0" in model else 16_777_216
|
||||
if out_num_pixels > max_pixels:
|
||||
raise ValueError(
|
||||
f"Maximum image resolution for the selected model is {max_pixels / 1_000_000:.2f}MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
||||
max_num_of_images = 14 if model == "seedream-5-0-260128" else 10
|
||||
if n_input_images > max_num_of_images:
|
||||
raise ValueError(
|
||||
f"Maximum of {max_num_of_images} reference images are supported, but {n_input_images} received."
|
||||
)
|
||||
if sequential_image_generation == "auto" and n_input_images + max_images > 15:
|
||||
raise ValueError(
|
||||
"The maximum number of generated images plus the number of reference images cannot exceed 15."
|
||||
@@ -364,6 +377,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
sequential_image_generation=sequential_image_generation,
|
||||
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
||||
watermark=watermark,
|
||||
output_format="png" if model == "seedream-5-0-260128" else None,
|
||||
),
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
|
||||
@@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
)
from comfy_api_nodes.apis.kling import (
    ImageToVideoWithAudioRequest,
    KlingAvatarRequest,
    MotionControlRequest,
    MultiPromptEntry,
    OmniImageParamImage,
@@ -74,6 +75,7 @@ from comfy_api_nodes.util import (
    upload_image_to_comfyapi,
    upload_images_to_comfyapi,
    upload_video_to_comfyapi,
    validate_audio_duration,
    validate_image_aspect_ratio,
    validate_image_dimensions,
    validate_string,
@@ -3139,6 +3141,103 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
        return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))


class KlingAvatarNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls) -> IO.Schema:
        return IO.Schema(
            node_id="KlingAvatarNode",
            display_name="Kling Avatar 2.0",
            category="api node/video/Kling",
            description="Generate broadcast-style digital human videos from a single photo and an audio file.",
            inputs=[
                IO.Image.Input(
                    "image",
                    tooltip="Avatar reference image. "
                    "Width and height must be at least 300px. Aspect ratio must be between 1:2.5 and 2.5:1.",
                ),
                IO.Audio.Input(
                    "sound_file",
                    tooltip="Audio input. Must be between 2 and 300 seconds in duration.",
                ),
                IO.Combo.Input("mode", options=["std", "pro"]),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    optional=True,
                    tooltip="Optional prompt to define avatar actions, emotions, and camera movements.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed controls whether the node should re-run; "
                    "results are non-deterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
                expr="""
                (
                    $prices := {"std": 0.056, "pro": 0.112};
                    {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        image: Input.Image,
        sound_file: Input.Audio,
        mode: str,
        seed: int,
        prompt: str = "",
    ) -> IO.NodeOutput:
        validate_image_dimensions(image, min_width=300, min_height=300)
        validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
        validate_audio_duration(sound_file, min_duration=2, max_duration=300)
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/kling/v1/videos/avatar/image2video", method="POST"),
            response_model=TaskStatusResponse,
            data=KlingAvatarRequest(
                image=await upload_image_to_comfyapi(cls, image),
                sound_file=await upload_audio_to_comfyapi(
                    cls, sound_file, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg"
                ),
                prompt=prompt or None,
                mode=mode,
            ),
        )
        if response.code:
            raise RuntimeError(
                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
            )
        final_response = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/kling/v1/videos/avatar/image2video/{response.data.task_id}"),
            response_model=TaskStatusResponse,
            status_extractor=lambda r: (r.data.task_status if r.data else None),
            max_poll_attempts=800,
        )
        return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))


class KlingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -3167,6 +3266,7 @@ class KlingExtension(ComfyExtension):
            MotionControl,
            KlingVideoNode,
            KlingFirstLastFrameNode,
            KlingAvatarNode,
        ]
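
For reference, the `PriceBadge` expression above is JSONata: for mode "std" it resolves to $0.056/second and for "pro" to $0.112/second. A plain-Python sketch of the same lookup (illustrative only; `price_badge` is a hypothetical helper, the real badge is built from the JSONata string):

# Sketch of the price lookup performed by the PriceBadge expr above.
prices = {"std": 0.056, "pro": 0.112}

def price_badge(mode: str) -> dict:
    return {"type": "usd", "usd": prices[mode], "format": {"suffix": "/second"}}

print(price_badge("pro"))  # {'type': 'usd', 'usd': 0.112, 'format': {'suffix': '/second'}}
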
@@ -6,6 +6,7 @@ import folder_paths
import json
import os
import re
import math
import torch
import comfy.utils

@@ -682,6 +683,172 @@ class ImageScaleToMaxDimension(IO.ComfyNode):
    upscale = execute  # TODO: remove


class SplitImageToTileList(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="SplitImageToTileList",
            category="image/batch",
            search_aliases=["split image", "tile image", "slice image"],
            display_name="Split Image into List of Tiles",
            description="Splits an image into a batched list of tiles with a specified overlap.",
            inputs=[
                IO.Image.Input("image"),
                IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION),
                IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION),
                IO.Int.Input("overlap", default=128, min=0, max=4096),
            ],
            outputs=[
                IO.Image.Output(is_output_list=True),
            ],
        )

    @staticmethod
    def get_grid_coords(width, height, tile_width, tile_height, overlap):
        coords = []
        stride_x = max(1, tile_width - overlap)
        stride_y = max(1, tile_height - overlap)

        y = 0
        while y < height:
            x = 0
            y_end = min(y + tile_height, height)
            y_start = max(0, y_end - tile_height)

            while x < width:
                x_end = min(x + tile_width, width)
                x_start = max(0, x_end - tile_width)

                coords.append((x_start, y_start, x_end, y_end))

                if x_end >= width:
                    break
                x += stride_x

            if y_end >= height:
                break
            y += stride_y

        return coords

    @classmethod
    def execute(cls, image, tile_width, tile_height, overlap):
        b, h, w, c = image.shape
        coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap)

        output_list = []
        for (x_start, y_start, x_end, y_end) in coords:
            tile = image[:, y_start:y_end, x_start:x_end, :]
            output_list.append(tile)

        return IO.NodeOutput(output_list)


class ImageMergeTileList(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageMergeTileList",
            display_name="Merge List of Tiles to Image",
            category="image/batch",
            search_aliases=["split image", "tile image", "slice image"],
            is_input_list=True,
            inputs=[
                IO.Image.Input("image_list"),
                IO.Int.Input("final_width", default=1024, min=64, max=32768),
                IO.Int.Input("final_height", default=1024, min=64, max=32768),
                IO.Int.Input("overlap", default=128, min=0, max=4096),
            ],
            outputs=[
                IO.Image.Output(is_output_list=False),
            ],
        )

    @staticmethod
    def get_grid_coords(width, height, tile_width, tile_height, overlap):
        coords = []
        stride_x = max(1, tile_width - overlap)
        stride_y = max(1, tile_height - overlap)

        y = 0
        while y < height:
            x = 0
            y_end = min(y + tile_height, height)
            y_start = max(0, y_end - tile_height)

            while x < width:
                x_end = min(x + tile_width, width)
                x_start = max(0, x_end - tile_width)

                coords.append((x_start, y_start, x_end, y_end))

                if x_end >= width:
                    break
                x += stride_x

            if y_end >= height:
                break
            y += stride_y

        return coords

    @classmethod
    def execute(cls, image_list, final_width, final_height, overlap):
        w = final_width[0]
        h = final_height[0]
        ovlp = overlap[0]
        feather_str = 1.0

        first_tile = image_list[0]
        b, t_h, t_w, c = first_tile.shape
        device = first_tile.device
        dtype = first_tile.dtype

        coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp)

        canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
        weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)

        if ovlp > 0:
            y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype))
            x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype))
            y_w = torch.clamp(y_w, min=1e-5)
            x_w = torch.clamp(x_w, min=1e-5)

            sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1)
            flat_mask = torch.ones_like(sine_mask)

            weight_mask = torch.lerp(flat_mask, sine_mask, feather_str)
        else:
            weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype)

        for i, (x_start, y_start, x_end, y_end) in enumerate(coords):
            if i >= len(image_list):
                break

            tile = image_list[i]

            region_h = y_end - y_start
            region_w = x_end - x_start

            real_h = min(region_h, tile.shape[1])
            real_w = min(region_w, tile.shape[2])

            y_end_actual = y_start + real_h
            x_end_actual = x_start + real_w

            tile_crop = tile[:, :real_h, :real_w, :]
            mask_crop = weight_mask[:, :real_h, :real_w, :]

            canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop
            weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop

        weights[weights == 0] = 1.0
        merged_image = canvas / weights

        return IO.NodeOutput(merged_image)


class ImagesExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -701,6 +868,8 @@ class ImagesExtension(ComfyExtension):
            ImageRotate,
            ImageFlip,
            ImageScaleToMaxDimension,
            SplitImageToTileList,
            ImageMergeTileList,
        ]
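
To see how the shared `get_grid_coords` logic above behaves, here is a minimal standalone sketch (it duplicates only the loop, not the node classes) run on a concrete 1536x1024 image with the default 1024x1024 tiles and 128px overlap. Edge tiles are shifted back so every tile keeps its full size, which is what `ImageMergeTileList` relies on when it cross-fades overlaps with the sine weight mask:

# Minimal sketch reproducing the grid logic above (stride = tile - overlap).
def grid_coords(width, height, tile_w, tile_h, overlap):
    coords = []
    stride_x, stride_y = max(1, tile_w - overlap), max(1, tile_h - overlap)
    y = 0
    while y < height:
        y_end = min(y + tile_h, height)
        y_start = max(0, y_end - tile_h)  # clamp so edge tiles keep full size
        x = 0
        while x < width:
            x_end = min(x + tile_w, width)
            x_start = max(0, x_end - tile_w)
            coords.append((x_start, y_start, x_end, y_end))
            if x_end >= width:
                break
            x += stride_x
        if y_end >= height:
            break
        y += stride_y
    return coords

print(grid_coords(1536, 1024, 1024, 1024, 128))
# [(0, 0, 1024, 1024), (512, 0, 1536, 1024)] -- the second tile is pulled back
# to stay 1024 px wide, so it overlaps the first by 512 px instead of 128.
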
@@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
    generate = execute  # TODO: remove


def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
    """Append a guide_attention_entry to both positive and negative conditioning.

    Each entry tracks one guide reference for per-reference attention control.
    Entries are derived independently from each conditioning to avoid cross-contamination.
    """
    new_entry = {
        "pre_filter_count": pre_filter_count,
        "strength": strength,
        "pixel_mask": None,
        "latent_shape": latent_shape,
    }
    results = []
    for cond in (positive, negative):
        # Read existing entries from this specific conditioning
        existing = []
        for t in cond:
            found = t[1].get("guide_attention_entries", None)
            if found is not None:
                existing = found
                break
        # Shallow copy and append (no deepcopy needed — entries contain
        # only scalars and None for pixel_mask at this call site).
        entries = [*existing, new_entry]
        results.append(node_helpers.conditioning_set_values(
            cond, {"guide_attention_entries": entries}
        ))
    return results[0], results[1]


def conditioning_get_any_value(conditioning, key, default=None):
    for t in conditioning:
        if key in t[1]:
@@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
            scale_factors,
        )

        # Track this guide for per-reference attention control.
        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
        guide_latent_shape = list(t.shape[2:])  # [F, H, W]
        positive, negative = _append_guide_attention_entry(
            positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
        )

        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

    generate = execute  # TODO: remove
@@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
        latent_image = latent_image[:, :, :-num_keyframes]
        noise_mask = noise_mask[:, :, :-num_keyframes]

        positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
        negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
        positive = node_helpers.conditioning_set_values(positive, {
            "keyframe_idxs": None,
            "guide_attention_entries": None,
        })
        negative = node_helpers.conditioning_set_values(negative, {
            "keyframe_idxs": None,
            "guide_attention_entries": None,
        })

        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
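
To make the accumulation behaviour of `_append_guide_attention_entry` concrete, here is a toy sketch (not this PR's code; `set_values` and `append_entry` are simplified stand-ins for node_helpers.conditioning_set_values and the helper above). Each LTXVAddGuide call appends one entry dict to the `guide_attention_entries` list carried on every conditioning pair, and LTXVCropGuides later clears the key again, as the last hunk shows:

# Toy sketch: conditioning is a list of [tensor, options] pairs as in ComfyUI.
import torch

def set_values(cond, values):
    # simplified stand-in for node_helpers.conditioning_set_values
    return [[t, {**opts, **values}] for t, opts in cond]

def append_entry(cond, entry):
    existing = next((opts["guide_attention_entries"] for _, opts in cond
                     if opts.get("guide_attention_entries") is not None), [])
    return set_values(cond, {"guide_attention_entries": [*existing, entry]})

cond = [[torch.zeros(1, 4), {}]]
cond = append_entry(cond, {"pre_filter_count": 1 * 32 * 32, "strength": 1.0})
cond = append_entry(cond, {"pre_filter_count": 8 * 32 * 32, "strength": 0.5})
print(len(cond[0][1]["guide_attention_entries"]))  # 2 -- one entry per guide
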
@@ -52,7 +52,7 @@ class ModelSamplingDiscrete:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
                              "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],),
                              "zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
                              }}

@@ -76,6 +76,8 @@ class ModelSamplingDiscrete:
            sampling_type = comfy.model_sampling.X0
        elif sampling == "img_to_img":
            sampling_type = comfy.model_sampling.IMG_TO_IMG
        elif sampling == "img_to_img_flow":
            sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass
@@ -10,7 +10,7 @@ class NAGuidance(io.ComfyNode):
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            category="advanced/guidance",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to apply NAG to."),
@@ -79,7 +79,6 @@ class Blur(io.ComfyNode):
            node_id="ImageBlur",
            display_name="Image Blur",
            category="image/postprocessing",
            essentials_category="Image Tools",
            inputs=[
                io.Image.Input("image"),
                io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
@@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode):
            node_id="BatchImagesNode",
            display_name="Batch Images",
            category="image",
            essentials_category="Image Tools",
            search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
            inputs=[
                io.Autogrow.Input("images", template=autogrow_template)
@@ -25,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):

    @classmethod
    def execute(cls, model, backend) -> io.NodeOutput:
        m = model.clone()
        m = model.clone(disable_dynamic=True)
        set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
        return io.NodeOutput(m)
@@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode):
            search_aliases=["extract frames", "split video", "video to images", "demux"],
            display_name="Get Video Components",
            category="image/video",
            essentials_category="Video Tools",
            description="Extracts all components from a video: frames, audio, and framerate.",
            inputs=[
                io.Video.Input("video", tooltip="The video to extract components from."),
@@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode):
                "start time",
            ],
            category="image/video",
            essentials_category="Video Tools",
            inputs=[
                io.Video.Input("video"),
                io.Float.Input(
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.14.1"
__version__ = "0.15.0"
nodes.py
@@ -1925,7 +1925,6 @@ class ImageInvert:

class ImageBatch:
    SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
    ESSENTIALS_CATEGORY = "Image Tools"

    @classmethod
    def INPUT_TYPES(s):
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.14.1"
version = "0.15.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
@@ -1,6 +1,6 @@
comfyui-frontend-package==1.39.14
comfyui-workflow-templates==0.8.43
comfyui-embedded-docs==0.4.1
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.3
comfyui-embedded-docs==0.4.3
torch
torchsde
torchvision
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.0
comfy-aimdo>=0.2.2
requests

#non essential dependencies: