Compare commits

...

29 Commits

Author SHA1 Message Date
comfyanonymous
50549aa252 ComfyUI v0.16.2 2026-03-05 13:41:06 -05:00
comfyanonymous
1c3b651c0a Refactor. (#12794) 2026-03-05 13:35:56 -05:00
ComfyUI Wiki
5073da57ad chore: update workflow templates to v0.9.10 (#12793) 2026-03-05 10:22:38 -08:00
rattus
42e0e023ee ops: Handle CPU weight in VBAR caster (#12792)
This shouldn't happen but custom nodes gets there. Handle it as best
we can.
2026-03-05 10:22:17 -08:00
rattus
6481569ad4 comfy-aimdo 0.2.7 (#12791)
Comfy-aimdo 0.2.7 fixes a crash when a spurious cudaAsyncFree comes in
and would cause an infinite stack overflow (via detours hooks).

A lock is also introduced on the link list holding the free sections
to avoid any possibility of threaded miscellaneous cuda allocations
being the root cause.
2026-03-05 09:04:24 -08:00
comfyanonymous
6ef82a89b8 ComfyUI v0.16.1 2026-03-05 10:38:33 -05:00
ComfyUI Wiki
da29b797ce Update workflow templates to v0.9.8 (#12788) 2026-03-05 07:23:23 -08:00
Alexander Piskun
9cdfd7403b feat(api-nodes): enable Kling 3.0 Motion Control (#12785) 2026-03-05 07:12:38 -08:00
Alexander Piskun
bd21363563 feat(api-nodes-xAI): updated models, pricing, added features (#12756) 2026-03-05 04:29:39 -08:00
comfyanonymous
e04d0dbeb8 ComfyUI v0.16.0 2026-03-05 04:06:29 -05:00
ComfyUI Wiki
c8428541a6 chore: update workflow templates to v0.9.7 (#12780) 2026-03-05 03:58:25 -05:00
comfyanonymous
4941671b5a Fix cuda getting initialized in cpu mode. (#12779) 2026-03-05 02:39:51 -05:00
ComfyUI Wiki
c5fe8ace68 chore: update workflow templates to v0.9.6 (#12778) 2026-03-05 02:37:35 -05:00
comfyanonymous
f2ee7f2d36 Fix cublas ops on dynamic vram. (#12776) 2026-03-05 01:21:55 -05:00
comfyanonymous
43c64b6308 Support the LTXAV 2.3 model. (#12773) 2026-03-04 20:06:20 -05:00
comfyanonymous
ac4a943ff3 Initial load device should be cpu when using dynamic vram. (#12766) 2026-03-04 16:33:14 -05:00
rattus
8811db52db comfy-aimdo 0.2.6 (#12764)
Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would manfiest
as an error after a number of workflow runs.
2026-03-04 12:12:37 -08:00
Jukka Seppänen
0a7446ade4 Pass tokens when loading text gen model for text generation (#12755)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-04 08:59:56 -08:00
rattus
9b85cf9558 Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754)
* ops: dont unpin nothing

This was calling into aimdo in the none case (offloaded weight). Whats worse,
is aimdo syncs for unpinning an offloaded weight, as that is the corner case of
a weight getting evicted by its own use which does require a sync. But this
was heppening every offloaded weight causing slowdown.

* mp: fix get_free_memory policy

The ModelPatcherDynamic get_free_memory was deducting the model from
to try and estimate the conceptual free memory with doing any
offloading. This is kind of what the old memory_memory_required
was estimating in ModelPatcher load logic, however in practical
reality, between over-estimates and padding, the loader usually
underloaded models enough such that sampling could send CFG +/-
through together even when partially loaded.

So don't regress from the status quo and instead go all in on the
idea that offloading is less of an issue than debatching. Tell the
sampler it can use everything.
2026-03-04 07:49:13 -08:00
rattus
d531e3fb2a model_patcher: Improve dynamic offload heuristic (#12759)
Define a threshold below which a weight loading takes priority. This
actually makes the offload consistent with non-dynamic, because what
happens, is when non-dynamic fills ints to_load list, it will fill-up
any left-over pieces that could fix large weights with small weights
and load them, even though they were lower priority. This actually
improves performance because the timy weights dont cost any VRAM and
arent worth the control overhead of the DMA etc.
2026-03-04 07:47:44 -08:00
Arthur R Longbottom
eb011733b6 Fix VideoFromComponents.save_to crash when writing to BytesIO (#12683)
* Fix VideoFromComponents.save_to crash when writing to BytesIO

When `get_container_format()` or `get_stream_source()` is called on a
tensor-based video (VideoFromComponents), it calls `save_to(BytesIO())`.
Since BytesIO has no file extension, `av.open` can't infer the output
format and throws `ValueError: Could not determine output format`.

The sibling class `VideoFromFile` already handles this correctly via
`get_open_write_kwargs()`, which detects BytesIO and sets the format
explicitly. `VideoFromComponents` just never got the same treatment.

This surfaces when any downstream node validates the container format
of a tensor-based video, like TopazVideoEnhance or any node that calls
`validate_container_format_is_mp4()`.

Three-line fix in `comfy_api/latest/_input_impl/video_types.py`.

* Add docstring to save_to to satisfy CI coverage check
2026-03-04 00:29:00 -05:00
rattus
ac6513e142 DynamicVram: Add casting / fix torch Buffer weights (#12749)
* respect model dtype in non-comfy caster

* utils: factor out parent and name functionality of set_attr

* utils: implement set_attr_buffer for torch buffers

* ModelPatcherDynamic: Implement torch Buffer loading

If there is a buffer in dynamic - force load it.
2026-03-03 18:19:40 -08:00
Terry Jia
b6ddc590ed CURVE type (#12581)
* CURVE type

* fix: update typed wrapper unwrap keys to __type__ and __value__

* code improve

* code improve
2026-03-03 16:58:53 -08:00
comfyanonymous
f719a9d928 Adjust memory usage factor of zeta model. (#12746) 2026-03-03 17:35:22 -05:00
rattus
174fd6759d main: Load aimdo after logger is setup (#12743)
This was too early. Aimdo can use the logger in error paths and this
causes a rogue default init if aimdo has something to log.
2026-03-03 08:51:15 -08:00
rattus
09bcbddfcf ModelPatcherDynamic: Force load all non-comfy weights (#12739)
* model_management: Remove non-comfy dynamic _v caster

* Force pre-load non-comfy weights to GPU in ModelPatcherDynamic

Non-comfy weights may expect to be pre-cast to the target
device without in-model casting. Previously they were allocated in
the vbar with _v which required the _v fault path in cast_to.
Instead, back up the original CPU weight and move it directly to GPU
at load time.
2026-03-03 08:50:33 -08:00
xeinherjer
dff0a4a158 Fix VAEDecodeAudioTiled ignoring tile_size input (#12735) (#12738) 2026-03-02 20:17:51 -05:00
Lodestone
9ebee0a217 Feat: z-image pixel space (model still training atm) (#12709)
* draft zeta (z-image pixel space)

* revert gitignore

* model loaded and able to run however vector direction still wrong tho

* flip the vector direction to original again this time

* Move wrongly positioned Z image pixel space class

* inherit Radiance LatentFormat class

* Fix parameters in classes for Zeta x0 dino

* remove arbitrary nn.init instances

* Remove unused import of lru_cache

---------

Co-authored-by: silveroxides <ishimarukaito@gmail.com>
2026-03-02 19:43:47 -05:00
comfyanonymous
57dd6c1aad Support loading zeta chroma weights properly. (#12734) 2026-03-02 18:54:18 -05:00
30 changed files with 1527 additions and 273 deletions

View File

@@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat):
def process_out(self, latent):
return latent
class ZImagePixelSpace(ChromaRadiance):
"""Pixel-space latent format for ZImage DCT variant.
No VAE encoding/decoding — the model operates directly on RGB pixels.
"""
pass

View File

@@ -2,11 +2,16 @@ from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
ADALN_BASE_PARAMS_COUNT,
ADALN_CROSS_ATTN_PARAMS_COUNT,
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
NormSingleLinearTextProjection,
LTXVModel,
apply_cross_attention_adaln,
compute_prompt_timestep,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
apply_gated_attention=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=v_dim,
@@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
@@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def _apply_text_cross_attention(
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
timestep, prompt_timestep, attention_mask, transformer_options,
):
"""Apply text cross-attention, with optional ADaLN modulation."""
if self.cross_attention_adaln:
shift_q, scale_q, gate = self.get_ada_values(
scale_shift_table, x.shape[0], timestep, slice(6, 9)
)
return apply_cross_attention_adaln(
x, context, attn, shift_q, scale_q, gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask, transformer_options,
)
return attn(
comfy.ldm.common_dit.rms_norm(x), context=context,
mask=attention_mask, transformer_options=transformer_options,
)
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
v_prompt_timestep=None, a_prompt_timestep=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
vx.add_(self._apply_text_cross_attention(
vx, v_context, self.attn2, self.scale_shift_table,
getattr(self, 'prompt_scale_shift_table', None),
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
)
# audio
if run_ax:
@@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
ax.add_(self._apply_text_cross_attention(
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
getattr(self, 'audio_prompt_scale_shift_table', None),
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
)
# video - audio cross attention.
if run_a2v or run_v2a:
@@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
apply_gated_attention=False,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
self.apply_gated_attention = apply_gated_attention
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
@@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
)
# Audio-specific AdaLN
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=audio_embedding_coefficient,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=2,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_prompt_adaln_single = None
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
@@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.audio_caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_caption_projection = lambda a: a
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
connector_split_rope = kwargs.get("rope_type", "split") == "split"
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
num_layers = kwargs.get("connector_num_layers", 2)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
def preprocess_text_embeds(self, context, unprocessed=False):
# LTXv2 fully processed context has dimension of self.caption_channels * 2
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
if not unprocessed:
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
return context
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
context_vid = context[:, :, :self.cross_attention_dim]
context_audio = context[:, :, self.cross_attention_dim:]
else:
context_vid = context
context_audio = context
if self.caption_proj_before_connector:
context_vid = self.caption_projection(context_vid)
context_audio = self.audio_caption_projection(context_audio)
out_vid = self.video_embeddings_connector(context_vid)[0]
out_audio = self.audio_embeddings_connector(context_audio)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
apply_gated_attention=self.apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
@@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep_flat,
timestep.max().expand_as(a_timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep_flat,
a_timestep.max().expand_as(timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep_flat * av_ca_factor,
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep_flat * av_ca_factor,
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
a_prompt_timestep = compute_prompt_timestep(
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
)
else:
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
a_prompt_timestep = None
return [v_timestep, a_timestep, cross_av_timestep_ss], [
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
v_embedded_timestep,
a_embedded_timestep,
]
], None
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
video_dim = vx.shape[-1]
audio_dim = ax.shape[-1]
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
context, [v_context_dim, a_context_dim], len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
if self.caption_proj_before_connector is False:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
a_context = a_context.view(batch_size, -1, audio_dim)
return [v_context, a_context], attention_mask
@@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
v_prompt_timestep = timestep[3]
a_prompt_timestep = timestep[4]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
v_prompt_timestep=args.get("v_prompt_timestep"),
a_prompt_timestep=args.get("a_prompt_timestep"),
)
return out
@@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
"v_prompt_timestep": v_prompt_timestep,
"a_prompt_timestep": a_prompt_timestep,
},
{"original_block": block_wrap},
)
@@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
v_prompt_timestep=v_prompt_timestep,
a_prompt_timestep=a_prompt_timestep,
)
return [vx, ax]

View File

@@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
d_head,
context_dim=None,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
heads=n_heads,
dim_head=d_head,
context_dim=None,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,

View File

@@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states
class NormSingleLinearTextProjection(nn.Module):
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
def __init__(
self, in_features, hidden_size, dtype=None, device=None, operations=None
):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.in_norm = operations.RMSNorm(
in_features, eps=1e-6, elementwise_affine=False
)
self.linear_1 = operations.Linear(
in_features, hidden_size, bias=True, dtype=dtype, device=device
)
self.hidden_size = hidden_size
self.in_features = in_features
def forward(self, caption):
caption = self.in_norm(caption)
caption = caption * (self.hidden_size / self.in_features) ** 0.5
return self.linear_1(caption)
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
@@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
dim_head=64,
dropout=0.0,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
# Optional per-head gating
if apply_gated_attention:
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
else:
self.to_gate_logits = None
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
@@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
gate_logits = self.to_gate_logits(x) # (B, T, H)
b, t, _ = out.shape
out = out.view(b, t, self.heads, self.dim_head)
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
out = out * gates.unsqueeze(-1)
out = out.view(b, t, self.heads * self.dim_head)
return self.to_out(out)
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
ADALN_BASE_PARAMS_COUNT = 6
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
class BasicTransformerBlock(nn.Module):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
@@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
operations=operations,
)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
if self.cross_attention_adaln:
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
x += apply_cross_attention_adaln(
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
)
else:
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
@@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
return x
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
"""Compute a single global prompt timestep for cross-attention ADaLN.
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
over text tokens. Returns None when *adaln_module* is None.
"""
if adaln_module is None:
return None
ts_input = (
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
if timestep_scaled.dim() > 1
else timestep_scaled.flatten()
)
prompt_ts, _ = adaln_module(
ts_input,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
def apply_cross_attention_adaln(
x, context, attn, q_shift, q_scale, q_gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask=None, transformer_options={},
):
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
that both regular tensors and CompressedTimestep are supported.
"""
batch_size = x.shape[0]
shift_kv, scale_kv = (
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
).unbind(dim=2)
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
@@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
caption_projection_first_linear=True,
dtype=None,
device=None,
operations=None,
@@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
self.caption_proj_before_connector = caption_proj_before_connector
self.cross_attention_adaln = cross_attention_adaln
self.caption_projection_first_linear = caption_projection_first_linear
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
@@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
else:
self.prompt_adaln_single = None
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.caption_projection = lambda a: a
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
@@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep_scaled = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
return timestep, embedded_timestep, prompt_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
if self.caption_proj_before_connector is False:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
@@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
merged_args["prompt_timestep"] = prompt_timestep
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
@@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
timestep_scale_multiplier=1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
@@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prompt_timestep = kwargs.get("prompt_timestep", None)
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
pe=pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
prompt_timestep=prompt_timestep,
)
return x

View File

@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
LATENT_DOWNSAMPLE_FACTOR = 4
@@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)

View File

@@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
super().__init__()
if config is None:
config = self._guess_config()
config = self.get_default_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
self.sampling_rate = model_config.get(
"sampling_rate", config.get("sampling_rate", 16000)
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
@@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
def get_default_config(self):
ddconfig = {
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"mel_bins": 64,
"z_channels": 8,
"resolution": 256,
"downsample_time": False,
"in_channels": 2,
"out_ch": 2,
"ch": 128,
"ch_mult": [1, 2, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
"mid_block_add_attention": False,
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
"causality_axis": "height",
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
"ddconfig": ddconfig,
"sampling_rate": 16000,
}
},
"preprocessing": {
"stft": {
"filter_length": 1024,
"hop_length": 160,
},
},
}
return config

View File

@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def in_meta_context():
return torch.device("meta") == torch.empty(0).device
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_space":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_time":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
else:
self.register_buffer(
"last_scale_shift_table",
torch.tensor(
[0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
else:
self.register_buffer(
"scale_shift_table",
torch.tensor(
[0.0, 0.0, 0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(4, in_channels),
persistent=False,
)
self.temporal_cache_state={}
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
super().__init__()
if config is None:
config = self.guess_config(version)
config = self.get_default_config(version)
self.config = config
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
self.decode_timestep = config.get("decode_timestep", 0.05)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
base_channels=config.get("encoder_base_channels", 128),
)
self.decoder = Decoder(
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
base_channels=config.get("decoder_base_channels", 128),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
self.per_channel_statistics = processor()
def guess_config(self, version):
def get_default_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=0.05, noise_scale=0.025):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)

View File

@@ -3,6 +3,7 @@ import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import numpy as np
import math
ops = comfy.ops.disable_weight_init
@@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------
def _sinc(x: torch.Tensor):
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
even = kernel_size % 2 == 0
half_size = kernel_size // 2
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride=1,
padding=True,
padding_mode="replicate",
kernel_size=12,
):
super().__init__()
if cutoff < -0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
super().__init__()
self.ratio = ratio
self.stride = ratio
if window_type == "hann":
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
# Uses replicate boundary padding, matching the reference resampler exactly.
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser-windowed sinc filter (BigVGAN default).
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter, persistent=persistent)
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
return self.lowpass(x)
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio=2,
down_ratio=2,
up_kernel_size=12,
down_kernel_size=12,
):
super().__init__()
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
# ---------------------------------------------------------------------------
# BigVGAN v2 activations (Snake / SnakeBeta)
# ---------------------------------------------------------------------------
class Snake(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
a = torch.exp(a)
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
class SnakeBeta(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.beta = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.beta.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1)
b = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
a = torch.exp(a)
b = torch.exp(b)
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
# ---------------------------------------------------------------------------
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
# ---------------------------------------------------------------------------
class AMPBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
super().__init__()
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.convs1 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
]
)
self.convs2 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
]
)
self.acts1 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
)
self.acts2 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
)
def forward(self, x):
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = x + xt
return x
# ---------------------------------------------------------------------------
# HiFi-GAN residual blocks
# ---------------------------------------------------------------------------
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
@@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
"""
def __init__(self, config=None):
@@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module):
config = self.get_default_config()
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
stereo = config.get("stereo", True)
resblock = config.get("resblock", "1")
activation = config.get("activation", "snake")
use_bias_at_final = config.get("use_bias_at_final", True)
# "output_sample_rate" is not present in recent checkpoint configs.
# When absent (None), AudioVAE.output_sample_rate computes it as:
# sample_rate * vocoder.upsample_factor / mel_hop_length
# where upsample_factor = product of all upsample stride lengths,
# and mel_hop_length is loaded from the autoencoder config at
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
self.output_sample_rate = config.get("output_sample_rate")
self.resblock = config.get("resblock", "1")
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
self.apply_final_activation = config.get("apply_final_activation", True)
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
if self.resblock == "1":
resblock_cls = ResBlock1
elif self.resblock == "2":
resblock_cls = ResBlock2
elif self.resblock == "AMP1":
resblock_cls = AMPBlock1
else:
raise ValueError(f"Unknown resblock type: {self.resblock}")
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock_class(ch, k, d))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
if self.resblock == "AMP1":
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
else:
self.resblocks.append(resblock_cls(ch, k, d))
out_channels = 2 if stereo else 1
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
if self.resblock == "AMP1":
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.act_post = Activation1d(act_cls(ch))
else:
self.act_post = nn.LeakyReLU()
self.conv_post = ops.Conv1d(
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
)
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
def get_default_config(self):
"""Generate default configuration for the vocoder."""
config = {
"resblock_kernel_sizes": [3, 7, 11],
"upsample_rates": [6, 5, 2, 2, 2],
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_rates": [5, 4, 2, 2, 2],
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"upsample_initial_channel": 1024,
"stereo": True,
"resblock": "1",
"activation": "snake",
"use_bias_at_final": True,
"use_tanh_at_final": True,
}
return config
@@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module):
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
if self.resblock != "AMP1":
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
@@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module):
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.act_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
if self.apply_final_activation:
if self.use_tanh_at_final:
x = torch.tanh(x)
else:
x = torch.clamp(x, -1, 1)
return x
class _STFTFn(nn.Module):
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
bfloat16 bases from training ensures the mel values fed to the BWE generator are
bit-identical to what it was trained on.
"""
def __init__(self, filter_length: int, hop_length: int, win_length: int):
super().__init__()
self.hop_length = hop_length
self.win_length = win_length
n_freqs = filter_length // 2 + 1
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute magnitude and phase spectrogram from a batch of waveforms.
Applies causal (left-only) padding of win_length - hop_length samples so that
each output frame depends only on past and present input — no lookahead.
The STFT is computed by convolving the padded signal with forward_basis.
Args:
y: Waveform tensor of shape (B, T).
Returns:
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
Computed in float32 for numerical stability, then cast back to
the input dtype.
"""
if y.dim() == 2:
y = y.unsqueeze(1) # (B, 1, T)
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
y = F.pad(y, (left_pad, 0))
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real ** 2 + imag ** 2)
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
return magnitude, phase
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
"""
def __init__(
self,
filter_length: int,
hop_length: int,
win_length: int,
n_mel_channels: int,
sampling_rate: int,
mel_fmin: float,
mel_fmax: float,
):
super().__init__()
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
def mel_spectrogram(
self, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute log-mel spectrogram and auxiliary spectral quantities.
Args:
y: Waveform tensor of shape (B, T).
Returns:
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
"""
magnitude, phase = self.stft_fn(y)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy
class VocoderWithBWE(torch.nn.Module):
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
"""
def __init__(self, config):
super().__init__()
vocoder_config = config["vocoder"]
bwe_config = config["bwe"]
self.vocoder = Vocoder(config=vocoder_config)
self.bwe_generator = Vocoder(
config={**bwe_config, "apply_final_activation": False}
)
self.input_sample_rate = bwe_config["input_sampling_rate"]
self.output_sample_rate = bwe_config["output_sampling_rate"]
self.hop_length = bwe_config["hop_length"]
self.mel_stft = MelSTFT(
filter_length=bwe_config["n_fft"],
hop_length=bwe_config["hop_length"],
win_length=bwe_config["n_fft"],
n_mel_channels=bwe_config["num_mels"],
sampling_rate=bwe_config["input_sampling_rate"],
mel_fmin=0.0,
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
)
self.resampler = UpSample1d(
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
persistent=False,
window_type="hann",
)
def _compute_mel(self, audio):
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
B, C, T = audio.shape
flat = audio.reshape(B * C, -1) # (B*C, T)
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
def forward(self, mel_spec):
x = self.vocoder(mel_spec)
_, _, T_low = x.shape
T_out = T_low * self.output_sample_rate // self.input_sample_rate
remainder = T_low % self.hop_length
if remainder != 0:
x = F.pad(x, (0, self.hop_length - remainder))
mel = self._compute_mel(x)
residual = self.bwe_generator(mel)
skip = self.resampler(x)
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
return torch.clamp(residual + skip, -1, 1)[..., :T_out]

View File

@@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
def invert_slices(slices, length):
@@ -858,3 +859,267 @@ class NextDiT(nn.Module):
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img
#############################################################################
# Pixel Space Decoder Components #
#############################################################################
def _modulate_shift_scale(x, shift, scale):
return x * (1 + scale) + shift
class PixelResBlock(nn.Module):
"""
Residual block with AdaLN modulation, zero-initialised so it starts as
an identity at the beginning of training.
"""
def __init__(self, channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
h = _modulate_shift_scale(self.in_ln(x), shift, scale)
h = self.mlp(h)
return x + gate * h
class DCTFinalLayer(nn.Module):
"""Zero-initialised output projection (adopted from DiT)."""
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm_final(x))
class SimpleMLPAdaLN(nn.Module):
"""
Small MLP decoder head for the pixel-space variant.
Takes per-patch pixel values and a per-patch conditioning vector from the
transformer backbone and predicts the denoised pixel values.
x : [B*N, P^2, C] noisy pixel values per patch position
c : [B*N, dim] backbone hidden state per patch (conditioning)
→ [B*N, P^2, C]
"""
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
z_channels: int,
num_res_blocks: int,
max_freqs: int = 8,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
# Project backbone hidden state → per-patch conditioning
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
# Input projection with DCT positional encoding
self.input_embedder = NerfEmbedder(
in_channels=in_channels,
hidden_size_input=model_channels,
max_freqs=max_freqs,
dtype=dtype,
device=device,
operations=operations,
)
# Residual blocks
self.res_blocks = nn.ModuleList([
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
])
# Output projection
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# x: [B*N, 1, P^2*C], c: [B*N, dim]
original_dtype = x.dtype
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
x = self.input_embedder(x) # [B*N, 1, model_channels]
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
x = x.to(weight_dtype)
for block in self.res_blocks:
x = block(x, y)
return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C]
#############################################################################
# NextDiT Pixel Space #
#############################################################################
class NextDiTPixelSpace(NextDiT):
"""
Pixel-space variant of NextDiT.
Identical transformer backbone to NextDiT, but the output head is replaced
with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
per patch rather than a single affine projection.
Key differences vs NextDiT:
• ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
• ``_forward`` stores the raw patchified pixel values before the backbone
embedding and feeds them to ``dec_net`` together with the per-patch
backbone hidden states.
• Supports optional x0 prediction via ``use_x0``.
"""
def __init__(
self,
# decoder-specific
decoder_hidden_size: int = 3840,
decoder_num_res_blocks: int = 4,
decoder_max_freqs: int = 8,
decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels)
use_x0: bool = False,
# all NextDiT args forwarded unchanged
**kwargs,
):
super().__init__(**kwargs)
# Remove the latent-space final layer not used in pixel space
del self.final_layer
patch_size = kwargs.get("patch_size", 2)
in_channels = kwargs.get("in_channels", 4)
dim = kwargs.get("dim", 4096)
# decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
self.dec_net = SimpleMLPAdaLN(
in_channels=dec_in_ch,
model_channels=decoder_hidden_size,
out_channels=dec_in_ch,
z_channels=dim,
num_res_blocks=decoder_num_res_blocks,
max_freqs=decoder_max_freqs,
dtype=kwargs.get("dtype"),
device=kwargs.get("device"),
operations=kwargs.get("operations"),
)
if use_x0:
self.register_buffer("__x0__", torch.tensor([]))
# ------------------------------------------------------------------
# Forward — mirrors NextDiT._forward exactly, replacing final_layer
# with the pixel-space dec_net decoder.
# ------------------------------------------------------------------
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
adaln_input = t
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
# ---- capture raw pixel patches before patchify_and_embed embeds them ----
pH = pW = self.patch_size
B, C, H, W = x.shape
pixel_patches = (
x.view(B, C, H // pH, pH, W // pW, pW)
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
.flatten(3) # [B, Ht, Wt, pH*pW*C]
.flatten(1, 2) # [B, N, pH*pW*C]
)
N = pixel_patches.shape[1]
# decoder sees one token per patch: [B*N, 1, P^2*C]
pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
x, cap_feats, cap_mask, adaln_input, num_tokens,
ref_latents=ref_latents, ref_contexts=ref_contexts,
siglip_feats=siglip_feats, transformer_options=transformer_options
)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
# ---- pixel-space decoder (replaces final_layer + unpatchify) ----
# img may have padding tokens beyond N; only the first N are real image patches
img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim]
decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim]
output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C]
output = output.reshape(B, N, -1) # [B, N, P^2*C]
# prepend zero cap placeholder so unpatchify indexing works unchanged
cap_placeholder = torch.zeros(
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
)
img_out = self.unpatchify(
torch.cat([cap_placeholder, output], dim=1),
img_size, cap_size, return_tensor=x_is_tensor
)[:, :, :h, :w]
return -img_out
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
# _forward returns neg_x0 = -x0 (negated decoder output).
#
# Reference inference (working_inference_reference.py):
# out = _forward(img, t) # = -x0
# pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual]
# img += (t_prev - t_curr) * pred # Euler step
#
# ComfyUI's Euler sampler does the same:
# x_next = x + (sigma_next - sigma) * model_output
# So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)

View File

@@ -1021,7 +1021,7 @@ class LTXAV(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@@ -1263,6 +1263,11 @@ class Lumina2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class ZImagePixelSpace(Lumina2):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
@@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if sig_weight is not None:
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
if dec_cond_key in state_dict_keys: # pixel-space variant
dit_config["image_model"] = "zimage_pixel"
# patch_size and in_channels are derived from x_embedder:
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
# patch_size: infer from decoder final layer output matching x_embedder input
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
dit_config["in_channels"] = 3
dit_config["decoder_in_channels"] = dec_in_ch
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
dit_config["decoder_num_res_blocks"] = count_blocks(
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
)
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
if '{}__x0__'.format(key_prefix) in state_dict_keys:
dit_config["use_x0"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
@@ -533,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
@@ -1055,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
n_layers = count_blocks(state_dict, 'layers.{}.')
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
for k in state_dict: # For zeta chroma
if k not in sd_map:
sd_map[k] = k
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')

View File

@@ -32,9 +32,6 @@ import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.torch
import comfy_aimdo.model_vbar
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -799,6 +796,8 @@ def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
for buf_name, buf in module.named_buffers(recurse=False):
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
def cleanup_models():
@@ -831,11 +830,14 @@ def unet_offload_device():
return torch.device("cpu")
def unet_inital_load_device(parameters, dtype):
cpu_dev = torch.device("cpu")
if comfy.memory_management.aimdo_enabled:
return cpu_dev
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
return cpu_dev
@@ -843,7 +845,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
if mem_dev > mem_cpu and model_size < mem_dev:
return torch_dev
else:
return cpu_dev
@@ -946,6 +948,9 @@ def text_encoder_device():
return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size=0):
if comfy.memory_management.aimdo_enabled:
return offload_device
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
return offload_device
@@ -1206,43 +1211,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
if hasattr(weight, "_v"):
#Unexpected usage patterns. There is no reason these don't work but they
#have no testing and no callers do this.
assert r is None
assert stream is None
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
if dtype is None:
dtype = weight._model_dtype
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
return v_tensor.to(dtype=dtype)
r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
#inconsistent between loaded and offloaded weights. So force the double casting
#that would happen in regular flow to make offload deterministic.
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
cast_buffer.copy_(weight, non_blocking=non_blocking)
weight = cast_buffer
r.copy_(weight, non_blocking=non_blocking)
return r
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@@ -1698,12 +1666,16 @@ def lora_compute_dtype(device):
return dtype
def synchronize():
if cpu_mode():
return
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
if cpu_mode():
return
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()

View File

@@ -241,6 +241,7 @@ class ModelPatcher:
self.patches = {}
self.backup = {}
self.backup_buffers = {}
self.object_patches = {}
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
@@ -306,10 +307,16 @@ class ModelPatcher:
return self.model.lowvram_patch_counter
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
#the vast majority of setups a little bit of offloading on the giant model more
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self):
return self.model, (self.backup, self.object_patches_backup, self.pinned)
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__
@@ -336,7 +343,7 @@ class ModelPatcher:
n.force_cast_weights = self.force_cast_weights
n.backup, n.object_patches_backup, n.pinned = model_override[1]
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
# attachments
n.attachments = {}
@@ -698,7 +705,7 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
def _load_list(self, for_dynamic=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
default = False
@@ -726,8 +733,13 @@ class ModelPatcher:
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
# Non-dynamic: prioritize by module offload cost.
if for_dynamic:
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
else:
sort_criteria = (module_offload_mem,)
loading.append(sort_criteria + (module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -1435,10 +1447,6 @@ class ModelPatcherDynamic(ModelPatcher):
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
#this is now way more dynamic and we dont support the same base model for both Dynamic
#and non-dynamic patchers.
if hasattr(self.model, "model_loaded_weight_memory"):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
self.non_dynamic_delegate_model = None
@@ -1461,15 +1469,7 @@ class ModelPatcherDynamic(ModelPatcher):
def loaded_size(self):
vbar = self._vbar_get()
if vbar is None:
return 0
return vbar.loaded_size()
def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if its not 100% true that this
#would all be avaiable for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
@@ -1504,6 +1504,7 @@ class ModelPatcherDynamic(ModelPatcher):
num_patches = 0
allocated_size = 0
self.model.model_loaded_weight_memory = 0
with self.use_ejected():
self.unpatch_hooks()
@@ -1512,15 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
#We force reserve VRAM for the non comfy-weight so we dont have to deal
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
loading = self._load_list(for_dynamic=True, default_device=device_to)
loading.sort()
for x in loading:
_, _, _, n, m, params = x
*_, module_mem, n, m, params = x
def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
@@ -1558,6 +1555,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
weight, _, _ = get_key_weight(self.model, key)
if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1583,21 +1583,26 @@ class ModelPatcherDynamic(ModelPatcher):
for param in params:
key = key_param_name_to_key(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
casted_weight = weight.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_param(self.model, key, casted_weight)
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
for key, buf in self.model.named_buffers(recurse=True):
if key not in self.backup_buffers:
self.backup_buffers[key] = buf
module, buf_name = comfy.utils.resolve_attr(self.model, key)
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
casted_buf = buf.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
@@ -1613,12 +1618,23 @@ class ModelPatcherDynamic(ModelPatcher):
assert self.load_device != torch.device("cpu")
vbar = self._vbar_get()
return 0 if vbar is None else vbar.free_memory(memory_to_free)
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
if freed < memory_to_free:
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
return freed
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
@@ -1640,11 +1656,6 @@ class ModelPatcherDynamic(ModelPatcher):
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):

View File

@@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
offload_stream = None
xfer_dest = None
@@ -269,8 +284,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
#FIXME: This is really bad RTTI
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
@@ -660,23 +675,29 @@ class fp8_ops(manual_cast):
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
from cublas_ops import CublasLinear, cublas_half_matmul
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass
if CUBLAS_IS_AVAILABLE:
class cublas_ops(disable_weight_init):
class Linear(CublasLinear, disable_weight_init.Linear):
class cublas_ops(manual_cast):
class Linear(CublasLinear, manual_cast.Linear):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
return super().forward(input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
# ==============================================================================
# Mixed Precision Operations

View File

@@ -428,7 +428,7 @@ class CLIP:
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
@@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.NEWBIE:

View File

@@ -1118,6 +1118,20 @@ class ZImage(Lumina2):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class ZImagePixelSpace(ZImage):
unet_config = {
"image_model": "zimage_pixel",
}
# Pixel-space model: no spatial compression, operates on raw RGB patches.
latent_format = latent_formats.ZImagePixelSpace
# Much lower memory than latent-space models (no VAE, small patches).
memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@@ -1720,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]

View File

@@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
super().__init__()
self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
def forward(self, x):
source_dim = x.shape[-1]
x = x.movedim(1, -1)
x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
return torch.cat((video, audio), dim=-1)
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.compat_mode = False
self.text_projection_type = text_projection_type
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
if self.text_projection_type == "single_linear":
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
elif self.text_projection_type == "dual_linear":
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
def enable_compat_mode(self): # TODO: remove
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
if self.text_projection_type == "single_linear":
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
return out.to(out_device), pooled
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
extra = {}
else:
extra = {"unprocessed_ltxav_embeds": True}
elif self.text_projection_type == "dual_linear":
out = self.text_embedding_projection(out)
extra = {"unprocessed_ltxav_embeds": True}
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
@@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
@@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module):
num_tokens = max(num_tokens, 642)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
@@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
return LTXAVTEModel_
def sd_detect(state_dict_list, prefix=""):
for sd in state_dict_list:
if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
return {"text_projection_type": "dual_linear"}
if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
return {"text_projection_type": "single_linear"}
return {}
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
ATTR_UNSET={}
def set_attr(obj, attr, value):
def resolve_attr(obj, attr):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
return obj, attrs[-1]
def set_attr(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
delattr(obj, name)
else:
setattr(obj, attrs[-1], value)
setattr(obj, name, value)
return prev
def set_attr_param(obj, attr, value):
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
persistent = name not in getattr(obj, "_non_persistent_buffers_set", set())
obj.register_buffer(name, value, persistent=persistent)
return prev
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")

View File

@@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput):
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
elif isinstance(path, io.BytesIO):
# BytesIO has no file extension, so av.open can't infer the format.
# Default to mp4 since that's the only supported format anyway.
extra_kwargs["format"] = "mp4"
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:

View File

@@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO):
return d
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
CurvePoint = tuple[float, float]
Type = list[CurvePoint]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = [(0.0, 0.0), (1.0, 1.0)]
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -2226,5 +2239,6 @@ __all__ = [
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"Curve",
"NodeReplace",
]

View File

@@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel):
aspect_ratio: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
resolution: str = Field(...)
class InputUrlObject(BaseModel):
@@ -16,12 +17,13 @@ class InputUrlObject(BaseModel):
class ImageEditRequest(BaseModel):
model: str = Field(...)
image: InputUrlObject = Field(...)
images: list[InputUrlObject] = Field(...)
prompt: str = Field(...)
resolution: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
aspect_ratio: str | None = Field(...)
class VideoGenerationRequest(BaseModel):
@@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel):
revised_prompt: str | None = Field(None)
class UsageObject(BaseModel):
cost_in_usd_ticks: int | None = Field(None)
class ImageGenerationResponse(BaseModel):
data: list[ImageResponseObject] = Field(...)
usage: UsageObject | None = Field(None)
class VideoGenerationResponse(BaseModel):
@@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel):
status: str | None = Field(None)
video: VideoResponseObject | None = Field(None)
model: str | None = Field(None)
usage: UsageObject | None = Field(None)

View File

@@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel):
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)

View File

@@ -27,6 +27,12 @@ from comfy_api_nodes.util import (
)
def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000
return None
class GrokImageNode(IO.ComfyNode):
@classmethod
@@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode):
category="api node/image/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.String.Input(
"prompt",
multiline=True,
@@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
],
outputs=[
IO.Image.Output(),
@@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": $rate * widgets.number_of_images}
)
""",
),
)
@@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio: str,
number_of_images: int,
seed: int,
resolution: str = "1K",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
response = await sync_op(
@@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
n=number_of_images,
seed=seed,
resolution=resolution.lower(),
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode):
category="api node/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Image.Input("image"),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.Image.Input("image", display_name="images"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="The text prompt used to generate the image",
),
IO.Combo.Input("resolution", options=["1K"]),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Int.Input(
"number_of_images",
default=1,
@@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"9:16",
"16:9",
"9:19.5",
"19.5:9",
"9:20",
"20:9",
"1:2",
"2:1",
],
optional=True,
tooltip="Only allowed when multiple images are connected to the image input.",
),
],
outputs=[
IO.Image.Output(),
@@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
)
""",
),
)
@@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode):
resolution: str,
number_of_images: int,
seed: int,
aspect_ratio: str = "auto",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
if model == "grok-imagine-image-pro":
if get_number_of_images(image) > 1:
raise ValueError("The pro model supports only 1 input image.")
elif get_number_of_images(image) > 3:
raise ValueError("A maximum of 3 input images is supported.")
if aspect_ratio != "auto" and get_number_of_images(image) == 1:
raise ValueError(
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
data=ImageEditRequest(
model=model,
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
prompt=prompt,
resolution=resolution.lower(),
n=number_of_images,
seed=seed,
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode):
category="api node/video/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
expr="""
(
$base := 0.181 * widgets.duration;
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
)
""",
@@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode):
category="api node/video/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""",
),
)
@@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))

View File

@@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode):
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True),
],
outputs=[
IO.Video.Output(),
@@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound: bool,
character_orientation: str,
mode: str,
model: str = "kling-v2-6",
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
@@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
model_name=model,
),
)
if response.code:

View File

@@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode):
def vae_decode_audio(vae, samples, tile=None, overlap=None):
if tile is not None:
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
else:
audio = vae.decode(samples["samples"]).movedim(-1, 1)

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.15.1"
__version__ = "0.16.2"

View File

@@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
else:
try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val
# Unwraps values wrapped in __value__ key or typed wrapper.
# This is used to pass list widget values to execution,
# as by default list value is reserved to represent the
# connection between nodes.
if isinstance(val, dict):
if "__value__" in val:
val = val["__value__"]
inputs[x] = val
if input_type == "INT":
val = int(val)

10
main.py
View File

@@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@@ -28,6 +23,11 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.15.1"
version = "0.16.2"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.5
comfyui-workflow-templates==0.9.10
comfyui-embedded-docs==0.4.3
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.4
comfy-aimdo>=0.2.7
requests
#non essential dependencies: