mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 22:30:00 +00:00
Compare commits
29 Commits
automation
...
v0.16.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50549aa252 | ||
|
|
1c3b651c0a | ||
|
|
5073da57ad | ||
|
|
42e0e023ee | ||
|
|
6481569ad4 | ||
|
|
6ef82a89b8 | ||
|
|
da29b797ce | ||
|
|
9cdfd7403b | ||
|
|
bd21363563 | ||
|
|
e04d0dbeb8 | ||
|
|
c8428541a6 | ||
|
|
4941671b5a | ||
|
|
c5fe8ace68 | ||
|
|
f2ee7f2d36 | ||
|
|
43c64b6308 | ||
|
|
ac4a943ff3 | ||
|
|
8811db52db | ||
|
|
0a7446ade4 | ||
|
|
9b85cf9558 | ||
|
|
d531e3fb2a | ||
|
|
eb011733b6 | ||
|
|
ac6513e142 | ||
|
|
b6ddc590ed | ||
|
|
f719a9d928 | ||
|
|
174fd6759d | ||
|
|
09bcbddfcf | ||
|
|
dff0a4a158 | ||
|
|
9ebee0a217 | ||
|
|
57dd6c1aad |
@@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat):
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
|
||||
class ZImagePixelSpace(ChromaRadiance):
|
||||
"""Pixel-space latent format for ZImage DCT variant.
|
||||
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -2,11 +2,16 @@ from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
ADALN_BASE_PARAMS_COUNT,
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT,
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
NormSingleLinearTextProjection,
|
||||
LTXVModel,
|
||||
apply_cross_attention_adaln,
|
||||
compute_prompt_timestep,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
@@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
@@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
@@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def _apply_text_cross_attention(
|
||||
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
|
||||
timestep, prompt_timestep, attention_mask, transformer_options,
|
||||
):
|
||||
"""Apply text cross-attention, with optional ADaLN modulation."""
|
||||
if self.cross_attention_adaln:
|
||||
shift_q, scale_q, gate = self.get_ada_values(
|
||||
scale_shift_table, x.shape[0], timestep, slice(6, 9)
|
||||
)
|
||||
return apply_cross_attention_adaln(
|
||||
x, context, attn, shift_q, scale_q, gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask, transformer_options,
|
||||
)
|
||||
return attn(
|
||||
comfy.ldm.common_dit.rms_norm(x), context=context,
|
||||
mask=attention_mask, transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||
v_prompt_timestep=None, a_prompt_timestep=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
vx.add_(self._apply_text_cross_attention(
|
||||
vx, v_context, self.attn2, self.scale_shift_table,
|
||||
getattr(self, 'prompt_scale_shift_table', None),
|
||||
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
@@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
ax.add_(self._apply_text_cross_attention(
|
||||
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
|
||||
getattr(self, 'audio_prompt_scale_shift_table', None),
|
||||
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
@@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
apply_gated_attention=False,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
self.apply_gated_attention = apply_gated_attention
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
@@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=audio_embedding_coefficient,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=2,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_prompt_adaln_single = None
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
@@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.audio_caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_caption_projection = lambda a: a
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
connector_split_rope = kwargs.get("rope_type", "split") == "split"
|
||||
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
|
||||
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
|
||||
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
|
||||
num_layers = kwargs.get("connector_num_layers", 2)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
|
||||
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def preprocess_text_embeds(self, context):
|
||||
if context.shape[-1] == self.caption_channels * 2:
|
||||
return context
|
||||
out_vid = self.video_embeddings_connector(context)[0]
|
||||
out_audio = self.audio_embeddings_connector(context)[0]
|
||||
def preprocess_text_embeds(self, context, unprocessed=False):
|
||||
# LTXv2 fully processed context has dimension of self.caption_channels * 2
|
||||
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
|
||||
if not unprocessed:
|
||||
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
|
||||
return context
|
||||
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
|
||||
context_vid = context[:, :, :self.cross_attention_dim]
|
||||
context_audio = context[:, :, self.cross_attention_dim:]
|
||||
else:
|
||||
context_vid = context
|
||||
context_audio = context
|
||||
if self.caption_proj_before_connector:
|
||||
context_vid = self.caption_projection(context_vid)
|
||||
context_audio = self.audio_caption_projection(context_audio)
|
||||
out_vid = self.video_embeddings_connector(context_vid)[0]
|
||||
out_audio = self.audio_embeddings_connector(context_audio)[0]
|
||||
return torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
@@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
apply_gated_attention=self.apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
@@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
timestep.max().expand_as(a_timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
a_timestep.max().expand_as(timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep_flat * av_ca_factor,
|
||||
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep_flat * av_ca_factor,
|
||||
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
|
||||
a_prompt_timestep = compute_prompt_timestep(
|
||||
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
else:
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
a_prompt_timestep = None
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
], None
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
video_dim = vx.shape[-1]
|
||||
audio_dim = ax.shape[-1]
|
||||
|
||||
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
|
||||
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
|
||||
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
context, [v_context_dim, a_context_dim], len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
a_context = a_context.view(batch_size, -1, audio_dim)
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
@@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
v_prompt_timestep = timestep[3]
|
||||
a_prompt_timestep = timestep[4]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
@@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
self_attention_mask=args.get("self_attention_mask"),
|
||||
v_prompt_timestep=args.get("v_prompt_timestep"),
|
||||
a_prompt_timestep=args.get("a_prompt_timestep"),
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
"self_attention_mask": self_attention_mask,
|
||||
"v_prompt_timestep": v_prompt_timestep,
|
||||
"a_prompt_timestep": a_prompt_timestep,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
@@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
v_prompt_timestep=v_prompt_timestep,
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
@@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
|
||||
d_head,
|
||||
context_dim=None,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
positional_embedding_max_pos=[4096],
|
||||
causal_temporal_positioning=False,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
|
||||
@@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NormSingleLinearTextProjection(nn.Module):
|
||||
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
|
||||
|
||||
def __init__(
|
||||
self, in_features, hidden_size, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.in_norm = operations.RMSNorm(
|
||||
in_features, eps=1e-6, elementwise_affine=False
|
||||
)
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features, hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.in_features = in_features
|
||||
|
||||
def forward(self, caption):
|
||||
caption = self.in_norm(caption)
|
||||
caption = caption * (self.hidden_size / self.in_features) ** 0.5
|
||||
return self.linear_1(caption)
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# Optional per-head gating
|
||||
if apply_gated_attention:
|
||||
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
@@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||
b, t, _ = out.shape
|
||||
out = out.view(b, t, self.heads, self.dim_head)
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
|
||||
out = out * gates.unsqueeze(-1)
|
||||
out = out.view(b, t, self.heads * self.dim_head)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
|
||||
ADALN_BASE_PARAMS_COUNT = 6
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
@@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
|
||||
x += apply_cross_attention_adaln(
|
||||
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
|
||||
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
|
||||
)
|
||||
else:
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
@@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
|
||||
"""Compute a single global prompt timestep for cross-attention ADaLN.
|
||||
|
||||
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
|
||||
over text tokens. Returns None when *adaln_module* is None.
|
||||
"""
|
||||
if adaln_module is None:
|
||||
return None
|
||||
ts_input = (
|
||||
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
|
||||
if timestep_scaled.dim() > 1
|
||||
else timestep_scaled.flatten()
|
||||
)
|
||||
prompt_ts, _ = adaln_module(
|
||||
ts_input,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
|
||||
|
||||
|
||||
def apply_cross_attention_adaln(
|
||||
x, context, attn, q_shift, q_scale, q_gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask=None, transformer_options={},
|
||||
):
|
||||
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
|
||||
|
||||
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
|
||||
that both regular tensors and CompressedTimestep are supported.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
shift_kv, scale_kv = (
|
||||
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
|
||||
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
|
||||
).unbind(dim=2)
|
||||
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
|
||||
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
|
||||
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
@@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
caption_projection_first_linear=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.caption_proj_before_connector = caption_proj_before_connector
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.caption_projection_first_linear = caption_projection_first_linear
|
||||
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
if self.cross_attention_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
else:
|
||||
self.prompt_adaln_single = None
|
||||
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.caption_projection = lambda a: a
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
@@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
return timestep, embedded_timestep, prompt_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
@@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
merged_args["prompt_timestep"] = prompt_timestep
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
@@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
@@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prompt_timestep = kwargs.get("prompt_timestep", None)
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
prompt_timestep=prompt_timestep,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
@@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
if "bwe" in component_config.vocoder:
|
||||
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||
else:
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
@@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
config = self.get_default_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
self.sampling_rate = model_config.get(
|
||||
"sampling_rate", config.get("sampling_rate", 16000)
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
@@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
def get_default_config(self):
|
||||
ddconfig = {
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"mel_bins": 64,
|
||||
"z_channels": 8,
|
||||
"resolution": 256,
|
||||
"downsample_time": False,
|
||||
"in_channels": 2,
|
||||
"out_ch": 2,
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [],
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
"causality_axis": "height",
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
"ddconfig": ddconfig,
|
||||
"sampling_rate": 16000,
|
||||
}
|
||||
},
|
||||
"preprocessing": {
|
||||
"stft": {
|
||||
"filter_length": 1024,
|
||||
"hop_length": 160,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
@@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def in_meta_context():
|
||||
return torch.device("meta") == torch.empty(0).device
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
@@ -350,6 +353,10 @@ class Decoder(nn.Module):
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_space":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_time":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
@@ -395,17 +402,21 @@ class Decoder(nn.Module):
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@@ -455,6 +466,15 @@ class Decoder(nn.Module):
|
||||
output_channel * 2, 0, operations=ops,
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
else:
|
||||
self.register_buffer(
|
||||
"last_scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(2, output_channel),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
@@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(4, in_channels),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.temporal_cache_state={}
|
||||
|
||||
@@ -1012,9 +1041,6 @@ class processor(nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
@@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.guess_config(version)
|
||||
config = self.get_default_config(version)
|
||||
|
||||
self.config = config
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
|
||||
self.decode_timestep = config.get("decode_timestep", 0.05)
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
@@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
base_channels=config.get("encoder_base_channels", 128),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
@@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
base_channels=config.get("decoder_base_channels", 128),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
@@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def guess_config(self, version):
|
||||
def get_default_config(self, version):
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
@@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||
def decode(self, x):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
@@ -3,6 +3,7 @@ import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: torch.Tensor):
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.0:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.0:
|
||||
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
if even:
|
||||
time = torch.arange(-half_size, half_size) + 0.5
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride=1,
|
||||
padding=True,
|
||||
padding_mode="replicate",
|
||||
kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
if cutoff < -0.0:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
|
||||
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
|
||||
# Uses replicate boundary padding, matching the reference resampler exactly.
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||
)
|
||||
|
||||
self.register_buffer("filter", filter, persistent=persistent)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||
)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation,
|
||||
up_ratio=2,
|
||||
down_ratio=2,
|
||||
up_kernel_size=12,
|
||||
down_kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 activations (Snake / SnakeBeta)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
b = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
b = torch.exp(b)
|
||||
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.acts1 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
|
||||
)
|
||||
self.acts2 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HiFi-GAN residual blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
@@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
@@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module):
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
activation = config.get("activation", "snake")
|
||||
use_bias_at_final = config.get("use_bias_at_final", True)
|
||||
|
||||
|
||||
# "output_sample_rate" is not present in recent checkpoint configs.
|
||||
# When absent (None), AudioVAE.output_sample_rate computes it as:
|
||||
# sample_rate * vocoder.upsample_factor / mel_hop_length
|
||||
# where upsample_factor = product of all upsample stride lengths,
|
||||
# and mel_hop_length is loaded from the autoencoder config at
|
||||
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.resblock = config.get("resblock", "1")
|
||||
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
|
||||
self.apply_final_activation = config.get("apply_final_activation", True)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
if self.resblock == "1":
|
||||
resblock_cls = ResBlock1
|
||||
elif self.resblock == "2":
|
||||
resblock_cls = ResBlock2
|
||||
elif self.resblock == "AMP1":
|
||||
resblock_cls = AMPBlock1
|
||||
else:
|
||||
raise ValueError(f"Unknown resblock type: {self.resblock}")
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
@@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module):
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
if self.resblock == "AMP1":
|
||||
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
|
||||
else:
|
||||
self.resblocks.append(resblock_cls(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
if self.resblock == "AMP1":
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.act_post = Activation1d(act_cls(ch))
|
||||
else:
|
||||
self.act_post = nn.LeakyReLU()
|
||||
|
||||
self.conv_post = ops.Conv1d(
|
||||
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
|
||||
)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_rates": [5, 4, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
"activation": "snake",
|
||||
"use_bias_at_final": True,
|
||||
"use_tanh_at_final": True,
|
||||
}
|
||||
|
||||
return config
|
||||
@@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module):
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if self.resblock != "AMP1":
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
@@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module):
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
|
||||
x = self.act_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
if self.use_tanh_at_final:
|
||||
x = torch.tanh(x)
|
||||
else:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class _STFTFn(nn.Module):
|
||||
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
|
||||
|
||||
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||
bit-identical to what it was trained on.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int):
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
|
||||
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||
|
||||
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||
each output frame depends only on past and present input — no lookahead.
|
||||
The STFT is computed by convolving the padded signal with forward_basis.
|
||||
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
Computed in float32 for numerical stability, then cast back to
|
||||
the input dtype.
|
||||
"""
|
||||
if y.dim() == 2:
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
    """Causal log-mel spectrogram whose projection buffers come from the checkpoint.

    Runs the causal STFT (_STFTFn) on the input waveform and projects the
    linear magnitude spectrum onto the mel filterbank stored in ``mel_basis``.
    The state-dict layout matches the 'mel_stft.*' checkpoint keys
    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        mel_fmin: float,
        mel_fmax: float,
    ):
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
        # Mel projection matrix; the zeros are placeholders overwritten on load.
        self.register_buffer(
            "mel_basis", torch.zeros(n_mel_channels, filter_length // 2 + 1)
        )

    def mel_spectrogram(
        self, y: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel: log(clamp(mel_basis @ magnitude, min=1e-5)),
                shape (B, n_mel_channels, T_frames).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame L2 norm over frequency, shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        energy = torch.norm(magnitude, dim=1)
        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
        log_mel = torch.clamp(mel, min=1e-5).log()
        return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class VocoderWithBWE(torch.nn.Module):
    """Vocoder followed by a bandwidth-extension (BWE) stage.

    The base vocoder turns a mel spectrogram into a low-rate waveform. The BWE
    generator then predicts a high-rate residual from a mel computed on that
    waveform; the residual is added to an upsampled (skip) copy of the
    low-rate signal and the sum is clamped to [-1, 1].
    """

    def __init__(self, config):
        super().__init__()
        voc_cfg = config["vocoder"]
        bwe_cfg = config["bwe"]

        self.vocoder = Vocoder(config=voc_cfg)
        # The BWE stage outputs a residual, so its final tanh/clamp is disabled.
        self.bwe_generator = Vocoder(config={**bwe_cfg, "apply_final_activation": False})

        self.input_sample_rate = bwe_cfg["input_sampling_rate"]
        self.output_sample_rate = bwe_cfg["output_sampling_rate"]
        self.hop_length = bwe_cfg["hop_length"]

        self.mel_stft = MelSTFT(
            filter_length=bwe_cfg["n_fft"],
            hop_length=bwe_cfg["hop_length"],
            win_length=bwe_cfg["n_fft"],
            n_mel_channels=bwe_cfg["num_mels"],
            sampling_rate=bwe_cfg["input_sampling_rate"],
            mel_fmin=0.0,
            mel_fmax=bwe_cfg["input_sampling_rate"] / 2.0,
        )
        self.resampler = UpSample1d(
            ratio=bwe_cfg["output_sampling_rate"] // bwe_cfg["input_sampling_rate"],
            persistent=False,
            window_type="hann",
        )

    def _compute_mel(self, audio):
        """Per-channel log-mel of (B, C, T) audio → (B, C, n_mels, T_frames)."""
        batch, channels, _ = audio.shape
        flat = audio.reshape(batch * channels, -1)
        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)
        return mel.reshape(batch, channels, mel.shape[1], mel.shape[2])

    def forward(self, mel_spec):
        """Decode ``mel_spec`` to a high-rate waveform of length T_low * out_sr // in_sr."""
        x = self.vocoder(mel_spec)
        t_low = x.shape[-1]
        t_out = t_low * self.output_sample_rate // self.input_sample_rate

        # Pad to a whole number of STFT hops before computing the BWE mel.
        leftover = t_low % self.hop_length
        if leftover:
            x = F.pad(x, (0, self.hop_length - leftover))

        residual = self.bwe_generator(self._compute_mel(x))
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        return torch.clamp(residual + skip, -1, 1)[..., :t_out]
|
||||
|
||||
@@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
@@ -858,3 +859,267 @@ class NextDiT(nn.Module):
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
return -img
|
||||
|
||||
|
||||
#############################################################################
|
||||
# Pixel Space Decoder Components #
|
||||
#############################################################################
|
||||
|
||||
def _modulate_shift_scale(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class PixelResBlock(nn.Module):
    """AdaLN-modulated residual MLP block.

    The conditioning vector ``y`` produces shift/scale/gate terms; the gated
    MLP output is added back onto the input. (Per the surrounding code's
    intent this modulation is meant to start as an identity at training
    init — at inference the weights simply come from the checkpoint.)
    """

    def __init__(self, channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Apply the gated residual MLP to ``x`` under conditioning ``y``."""
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        # Inline AdaLN modulation: norm, then scale by (1 + scale) and shift.
        hidden = self.in_ln(x) * (1 + scale) + shift
        hidden = self.mlp(hidden)
        return x + gate * hidden
|
||||
|
||||
|
||||
class DCTFinalLayer(nn.Module):
    """Final LayerNorm + Linear output projection (DiT-style head).

    (In the DiT recipe this projection is zero-initialised at training time;
    here the weights are loaded from the checkpoint, so no init is done.)
    """

    def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(
            model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.linear = operations.Linear(
            model_channels, out_channels, bias=True, dtype=dtype, device=device
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` then project it to ``out_channels``."""
        normed = self.norm_final(x)
        return self.linear(normed)
|
||||
|
||||
|
||||
class SimpleMLPAdaLN(nn.Module):
    """MLP decoder head for the pixel-space variant.

    Maps noisy per-patch pixel values, conditioned on the transformer
    backbone's per-patch hidden state, to denoised pixel values.

    Shapes:
        x: [B*N, P^2, C]   noisy pixel values per patch position
        c: [B*N, dim]      backbone hidden state (conditioning)
        returns [B*N, P^2, C]
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        max_freqs: int = 8,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        # Backbone hidden state → per-patch conditioning vector.
        self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)

        # Pixel input projection with DCT positional encoding.
        self.input_embedder = NerfEmbedder(
            in_channels=in_channels,
            hidden_size_input=model_channels,
            max_freqs=max_freqs,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.res_blocks = nn.ModuleList(
            PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations)
            for _ in range(num_res_blocks)
        )

        self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Decode pixels ``x`` conditioned on backbone states ``c``.

        The computation runs in the dtype of the decoder weights (falling back
        to ``self.dtype`` or the input dtype when the weight is unavailable,
        e.g. under quantized ops) and the result is cast back to the input
        dtype.
        """
        input_dtype = x.dtype
        cond_weight = getattr(self.cond_embed, "weight", None)
        run_dtype = cond_weight.dtype if cond_weight is not None else (self.dtype or x.dtype)
        hidden = self.input_embedder(x)  # [B*N, 1, model_channels]
        cond = self.cond_embed(c.to(run_dtype)).unsqueeze(1)  # [B*N, 1, model_channels]
        hidden = hidden.to(run_dtype)
        for blk in self.res_blocks:
            hidden = blk(hidden, cond)
        return self.final_layer(hidden).to(input_dtype)  # [B*N, 1, P^2*C]
|
||||
|
||||
|
||||
#############################################################################
|
||||
# NextDiT – Pixel Space #
|
||||
#############################################################################
|
||||
|
||||
class NextDiTPixelSpace(NextDiT):
    """
    Pixel-space variant of NextDiT.

    Identical transformer backbone to NextDiT, but the output head is replaced
    with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
    per patch rather than a single affine projection.

    Key differences vs NextDiT:
      • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
      • ``_forward`` stores the raw patchified pixel values before the backbone
        embedding and feeds them to ``dec_net`` together with the per-patch
        backbone hidden states.
      • Supports optional x0 prediction via ``use_x0``.
    """

    def __init__(
        self,
        # decoder-specific
        decoder_hidden_size: int = 3840,
        decoder_num_res_blocks: int = 4,
        decoder_max_freqs: int = 8,
        decoder_in_channels: int = None,  # full flattened patch size (patch_size^2 * in_channels)
        use_x0: bool = False,
        # all NextDiT args forwarded unchanged
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Remove the latent-space final layer – not used in pixel space
        del self.final_layer

        patch_size = kwargs.get("patch_size", 2)
        in_channels = kwargs.get("in_channels", 4)
        dim = kwargs.get("dim", 4096)

        # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
        dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels

        self.dec_net = SimpleMLPAdaLN(
            in_channels=dec_in_ch,
            model_channels=decoder_hidden_size,
            out_channels=dec_in_ch,
            z_channels=dim,
            num_res_blocks=decoder_num_res_blocks,
            max_freqs=decoder_max_freqs,
            dtype=kwargs.get("dtype"),
            device=kwargs.get("device"),
            operations=kwargs.get("operations"),
        )

        if use_x0:
            # Empty marker buffer; its presence in the state dict signals the
            # x0-prediction variant (see detect_unet_config's '__x0__' check).
            self.register_buffer("__x0__", torch.tensor([]))

    # ------------------------------------------------------------------
    # Forward — mirrors NextDiT._forward exactly, replacing final_layer
    # with the pixel-space dec_net decoder.
    # ------------------------------------------------------------------
    # NOTE(review): mutable default args ([], {}) are shared across calls;
    # this matches the parent signature and is safe only while they are
    # never mutated here — confirm against NextDiT._forward.
    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
        """Run the backbone, decode pixels with dec_net, and return -x0.

        ``x`` is (B, C, H, W); the return value has the same spatial size
        (cropped back to the pre-padding h, w) and is the NEGATED decoder
        output (see ``forward`` for how the sign is consumed).
        """
        omni = len(ref_latents) > 0
        if omni:
            # Reference latents get timestep 0 (clean), the main batch keeps t.
            timesteps = torch.cat([timesteps * 0, timesteps], dim=0)

        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
        adaln_input = t

        if self.clip_text_pooled_proj is not None:
            pooled = kwargs.get("clip_text_pooled", None)
            if pooled is not None:
                pooled = self.clip_text_pooled_proj(pooled)
            else:
                # No pooled text embed provided: condition on zeros instead.
                pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))

        # ---- capture raw pixel patches before patchify_and_embed embeds them ----
        pH = pW = self.patch_size
        B, C, H, W = x.shape
        pixel_patches = (
            x.view(B, C, H // pH, pH, W // pW, pW)
            .permute(0, 2, 4, 3, 5, 1)  # [B, Ht, Wt, pH, pW, C]
            .flatten(3)  # [B, Ht, Wt, pH*pW*C]
            .flatten(1, 2)  # [B, N, pH*pW*C]
        )
        N = pixel_patches.shape[1]
        # decoder sees one token per patch: [B*N, 1, P^2*C]
        pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)

        patches = transformer_options.get("patches", {})
        x_is_tensor = isinstance(x, torch.Tensor)
        img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
            x, cap_feats, cap_mask, adaln_input, num_tokens,
            ref_latents=ref_latents, ref_contexts=ref_contexts,
            siglip_feats=siglip_feats, transformer_options=transformer_options
        )
        freqs_cis = freqs_cis.to(img.device)

        transformer_options["total_blocks"] = len(self.layers)
        transformer_options["block_type"] = "double"
        img_input = img
        for i, layer in enumerate(self.layers):
            transformer_options["block_index"] = i
            img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
            if "double_block" in patches:
                # Patch callbacks see caption tokens (txt) and image tokens
                # (img) separately; their returns are written back in place.
                for p in patches["double_block"]:
                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                    if "img" in out:
                        img[:, cap_size[0]:] = out["img"]
                    if "txt" in out:
                        img[:, :cap_size[0]] = out["txt"]

        # ---- pixel-space decoder (replaces final_layer + unpatchify) ----
        # img may have padding tokens beyond N; only the first N are real image patches
        img_hidden = img[:, cap_size[0]:cap_size[0] + N, :]  # [B, N, dim]
        decoder_cond = img_hidden.reshape(B * N, self.dim)  # [B*N, dim]

        output = self.dec_net(pixel_values, decoder_cond)  # [B*N, 1, P^2*C]
        output = output.reshape(B, N, -1)  # [B, N, P^2*C]

        # prepend zero cap placeholder so unpatchify indexing works unchanged
        cap_placeholder = torch.zeros(
            B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
        )
        img_out = self.unpatchify(
            torch.cat([cap_placeholder, output], dim=1),
            img_size, cap_size, return_tensor=x_is_tensor
        )[:, :, :h, :w]

        return -img_out

    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        """Return the flow-model output (x + x0) / t expected by the Euler sampler."""
        # _forward returns neg_x0 = -x0 (negated decoder output).
        #
        # Reference inference (working_inference_reference.py):
        #     out  = _forward(img, t)          # = -x0
        #     pred = (img - out) / t           # = (img + x0) / t   [_apply_x0_residual]
        #     img += (t_prev - t_curr) * pred  # Euler step
        #
        # ComfyUI's Euler sampler does the same:
        #     x_next = x + (sigma_next - sigma) * model_output
        # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
        neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
        ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)

        # NOTE(review): divides by the raw timestep — assumes the sampler never
        # calls this with t == 0; confirm against the sampling schedule.
        return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)
|
||||
|
||||
@@ -1021,7 +1021,7 @@ class LTXAV(BaseModel):
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
@@ -1263,6 +1263,11 @@ class Lumina2(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class ZImagePixelSpace(Lumina2):
    """Model wrapper for the pixel-space ZImage (DCT) variant.

    Inherits Lumina2's conditioning behavior but calls ``BaseModel.__init__``
    directly so the underlying diffusion model is ``NextDiTPixelSpace``
    instead of the latent-space NextDiT.
    """

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
        # ref_latents are factored into the memory-usage estimate for this model.
        self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
@@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if sig_weight is not None:
|
||||
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
||||
|
||||
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
|
||||
if dec_cond_key in state_dict_keys: # pixel-space variant
|
||||
dit_config["image_model"] = "zimage_pixel"
|
||||
# patch_size and in_channels are derived from x_embedder:
|
||||
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
|
||||
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
|
||||
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
|
||||
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
|
||||
# patch_size: infer from decoder final layer output matching x_embedder input
|
||||
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
|
||||
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
|
||||
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
|
||||
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["decoder_in_channels"] = dec_in_ch
|
||||
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
|
||||
dit_config["decoder_num_res_blocks"] = count_blocks(
|
||||
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
|
||||
)
|
||||
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
|
||||
if '{}__x0__'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_x0"] = True
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
@@ -533,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
|
||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2_1"
|
||||
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
||||
@@ -1055,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
|
||||
n_layers = count_blocks(state_dict, 'layers.{}.')
|
||||
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
|
||||
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
|
||||
for k in state_dict: # For zeta chroma
|
||||
if k not in sd_map:
|
||||
sd_map[k] = k
|
||||
elif 'x_embedder.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
|
||||
@@ -32,9 +32,6 @@ import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
|
||||
import comfy_aimdo.torch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@@ -799,6 +796,8 @@ def archive_model_dtypes(model):
|
||||
for name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
||||
for buf_name, buf in module.named_buffers(recurse=False):
|
||||
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
@@ -831,11 +830,14 @@ def unet_offload_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
def unet_inital_load_device(parameters, dtype):
|
||||
cpu_dev = torch.device("cpu")
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
return cpu_dev
|
||||
|
||||
torch_dev = get_torch_device()
|
||||
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
|
||||
return torch_dev
|
||||
|
||||
cpu_dev = torch.device("cpu")
|
||||
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
|
||||
return cpu_dev
|
||||
|
||||
@@ -843,7 +845,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
|
||||
if mem_dev > mem_cpu and model_size < mem_dev:
|
||||
return torch_dev
|
||||
else:
|
||||
return cpu_dev
|
||||
@@ -946,6 +948,9 @@ def text_encoder_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
return offload_device
|
||||
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
@@ -1206,43 +1211,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
if hasattr(weight, "_v"):
|
||||
#Unexpected usage patterns. There is no reason these don't work but they
|
||||
#have no testing and no callers do this.
|
||||
assert r is None
|
||||
assert stream is None
|
||||
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
||||
|
||||
if dtype is None:
|
||||
dtype = weight._model_dtype
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = weight._v_tensor
|
||||
else:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
weight._v_tensor = v_tensor
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
return v_tensor.to(dtype=dtype)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
|
||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||
#Offloaded casting could skip this, however it would make the quantizations
|
||||
#inconsistent between loaded and offloaded weights. So force the double casting
|
||||
#that would happen in regular flow to make offload deterministic.
|
||||
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
||||
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
||||
weight = cast_buffer
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
|
||||
return r
|
||||
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
@@ -1698,12 +1666,16 @@ def lora_compute_dtype(device):
|
||||
return dtype
|
||||
|
||||
def synchronize():
    """Block until all queued work on the active compute device has finished.

    No-op in CPU mode; dispatches to the XPU backend when running on Intel
    GPUs, otherwise to CUDA when it is available.
    """
    if cpu_mode():
        return
    if is_intel_xpu():
        torch.xpu.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
if cpu_mode():
|
||||
return
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
|
||||
@@ -241,6 +241,7 @@ class ModelPatcher:
|
||||
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.backup_buffers = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
@@ -306,10 +307,16 @@ class ModelPatcher:
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def get_free_memory(self, device):
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
|
||||
#the vast majority of setups a little bit of offloading on the giant model more
|
||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||
aimdo_mem = 0
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.object_patches_backup, self.pinned)
|
||||
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
class_ = self.__class__
|
||||
@@ -336,7 +343,7 @@ class ModelPatcher:
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
n.backup, n.object_patches_backup, n.pinned = model_override[1]
|
||||
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
|
||||
|
||||
# attachments
|
||||
n.attachments = {}
|
||||
@@ -698,7 +705,7 @@ class ModelPatcher:
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||
def _load_list(self, for_dynamic=False, default_device=None):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
default = False
|
||||
@@ -726,8 +733,13 @@ class ModelPatcher:
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
|
||||
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
|
||||
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
|
||||
# Non-dynamic: prioritize by module offload cost.
|
||||
if for_dynamic:
|
||||
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
|
||||
else:
|
||||
sort_criteria = (module_offload_mem,)
|
||||
loading.append(sort_criteria + (module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@@ -1435,10 +1447,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||
#this is now way more dynamic and we dont support the same base model for both Dynamic
|
||||
#and non-dynamic patchers.
|
||||
if hasattr(self.model, "model_loaded_weight_memory"):
|
||||
del self.model.model_loaded_weight_memory
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
self.non_dynamic_delegate_model = None
|
||||
@@ -1461,15 +1469,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def loaded_size(self):
|
||||
vbar = self._vbar_get()
|
||||
if vbar is None:
|
||||
return 0
|
||||
return vbar.loaded_size()
|
||||
|
||||
def get_free_memory(self, device):
|
||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||
#all non-dynamic models so this is safe even if its not 100% true that this
|
||||
#would all be avaiable for inference use.
|
||||
return comfy.model_management.get_total_memory(device) - self.model_size()
|
||||
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
|
||||
|
||||
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
|
||||
|
||||
@@ -1504,6 +1504,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
@@ -1512,15 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
#We force reserve VRAM for the non comfy-weight so we dont have to deal
|
||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||
#with a high layer rate (e.g. autoregressive LLMs).
|
||||
#prioritize the non-comfy weights (note the order reverse).
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||
loading.sort(reverse=True)
|
||||
loading = self._load_list(for_dynamic=True, default_device=device_to)
|
||||
loading.sort()
|
||||
|
||||
for x in loading:
|
||||
_, _, _, n, m, params = x
|
||||
*_, module_mem, n, m, params = x
|
||||
|
||||
def set_dirty(item, dirty):
|
||||
if dirty or not hasattr(item, "_v_signature"):
|
||||
@@ -1558,6 +1555,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is not None:
|
||||
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
@@ -1583,21 +1583,26 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
for param in params:
|
||||
key = key_param_name_to_key(n, param)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
weight.seed_key = key
|
||||
set_dirty(weight, dirty)
|
||||
geometry = weight
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
|
||||
casted_weight = weight.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_param(self.model, key, casted_weight)
|
||||
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||
for key, buf in self.model.named_buffers(recurse=True):
|
||||
if key not in self.backup_buffers:
|
||||
self.backup_buffers[key] = buf
|
||||
module, buf_name = comfy.utils.resolve_attr(self.model, key)
|
||||
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
|
||||
casted_buf = buf.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
@@ -1613,12 +1618,23 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
assert self.load_device != torch.device("cpu")
|
||||
|
||||
vbar = self._vbar_get()
|
||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
if freed < memory_to_free:
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
return freed
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
*_, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
@@ -1640,11 +1656,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
keys = list(self.backup.keys())
|
||||
for k in keys:
|
||||
bk = self.backup[k]
|
||||
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||
|
||||
37
comfy/ops.py
37
comfy/ops.py
@@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
|
||||
@@ -269,8 +284,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
return
|
||||
os, weight_a, bias_a = offload_stream
|
||||
device=None
|
||||
#FIXME: This is not good RTTI
|
||||
if not isinstance(weight_a, torch.Tensor):
|
||||
#FIXME: This is really bad RTTI
|
||||
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
device = weight_a
|
||||
if os is None:
|
||||
@@ -660,23 +675,29 @@ class fp8_ops(manual_cast):
|
||||
|
||||
CUBLAS_IS_AVAILABLE = False
|
||||
try:
|
||||
from cublas_ops import CublasLinear
|
||||
from cublas_ops import CublasLinear, cublas_half_matmul
|
||||
CUBLAS_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if CUBLAS_IS_AVAILABLE:
|
||||
class cublas_ops(disable_weight_init):
|
||||
class Linear(CublasLinear, disable_weight_init.Linear):
|
||||
class cublas_ops(manual_cast):
|
||||
class Linear(CublasLinear, manual_cast.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
return super().forward(input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
|
||||
@@ -428,7 +428,7 @@ class CLIP:
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model()
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
@@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.NEWBIE:
|
||||
|
||||
@@ -1118,6 +1118,20 @@ class ZImage(Lumina2):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
|
||||
|
||||
class ZImagePixelSpace(ZImage):
|
||||
unet_config = {
|
||||
"image_model": "zimage_pixel",
|
||||
}
|
||||
|
||||
# Pixel-space model: no spatial compression, operates on raw RGB patches.
|
||||
latent_format = latent_formats.ZImagePixelSpace
|
||||
|
||||
# Much lower memory than latent-space models (no VAE, small patches).
|
||||
memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1720,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
|
||||
self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
source_dim = x.shape[-1]
|
||||
x = x.movedim(1, -1)
|
||||
x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
|
||||
|
||||
video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
|
||||
audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
|
||||
return torch.cat((video, audio), dim=-1)
|
||||
|
||||
class LTXAVTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
self.dtypes.add(dtype)
|
||||
self.compat_mode = False
|
||||
self.text_projection_type = text_projection_type
|
||||
|
||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||
self.dtypes.add(dtype_llama)
|
||||
|
||||
operations = self.gemma3_12b.operations # TODO
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
|
||||
if self.text_projection_type == "single_linear":
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
elif self.text_projection_type == "dual_linear":
|
||||
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
|
||||
def enable_compat_mode(self): # TODO: remove
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
@@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
out_device = out.device
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
out = out.movedim(1, -1).to(self.execution_device)
|
||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
out = out.float()
|
||||
|
||||
if self.compat_mode:
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
if self.text_projection_type == "single_linear":
|
||||
out = out.movedim(1, -1).to(self.execution_device)
|
||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
|
||||
return out.to(out_device), pooled
|
||||
if self.compat_mode:
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
extra = {}
|
||||
else:
|
||||
extra = {"unprocessed_ltxav_embeds": True}
|
||||
elif self.text_projection_type == "dual_linear":
|
||||
out = self.text_embedding_projection(out)
|
||||
extra = {"unprocessed_ltxav_embeds": True}
|
||||
|
||||
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
@@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma3_12b.load_sd(sd)
|
||||
else:
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
@@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
num_tokens = max(num_tokens, 642)
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
|
||||
class LTXAVTEModel_(LTXAVTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
@@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
|
||||
return LTXAVTEModel_
|
||||
|
||||
|
||||
def sd_detect(state_dict_list, prefix=""):
|
||||
for sd in state_dict_list:
|
||||
if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
|
||||
return {"text_projection_type": "dual_linear"}
|
||||
if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
|
||||
return {"text_projection_type": "single_linear"}
|
||||
return {}
|
||||
|
||||
|
||||
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Gemma3_12BModel_(Gemma3_12BModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
|
||||
@@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
|
||||
ATTR_UNSET={}
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
def resolve_attr(obj, attr):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1], ATTR_UNSET)
|
||||
return obj, attrs[-1]
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
obj, name = resolve_attr(obj, attr)
|
||||
prev = getattr(obj, name, ATTR_UNSET)
|
||||
if value is ATTR_UNSET:
|
||||
delattr(obj, attrs[-1])
|
||||
delattr(obj, name)
|
||||
else:
|
||||
setattr(obj, attrs[-1], value)
|
||||
setattr(obj, name, value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
obj, name = resolve_attr(obj, attr)
|
||||
prev = getattr(obj, name, ATTR_UNSET)
|
||||
persistent = name not in getattr(obj, "_non_persistent_buffers_set", set())
|
||||
obj.register_buffer(name, value, persistent=persistent)
|
||||
return prev
|
||||
|
||||
def copy_to_param(obj, attr, value):
|
||||
# inplace update tensor instead of replacing it
|
||||
attrs = attr.split(".")
|
||||
|
||||
@@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput):
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
@@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput):
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
elif isinstance(path, io.BytesIO):
|
||||
# BytesIO has no file extension, so av.open can't infer the format.
|
||||
# Default to mp4 since that's the only supported format anyway.
|
||||
extra_kwargs["format"] = "mp4"
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
|
||||
@@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO):
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@@ -2226,5 +2239,6 @@ __all__ = [
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel):
|
||||
aspect_ratio: str = Field(...)
|
||||
n: int = Field(...)
|
||||
seed: int = Field(...)
|
||||
response_for: str = Field("url")
|
||||
response_format: str = Field("url")
|
||||
resolution: str = Field(...)
|
||||
|
||||
|
||||
class InputUrlObject(BaseModel):
|
||||
@@ -16,12 +17,13 @@ class InputUrlObject(BaseModel):
|
||||
|
||||
class ImageEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
image: InputUrlObject = Field(...)
|
||||
images: list[InputUrlObject] = Field(...)
|
||||
prompt: str = Field(...)
|
||||
resolution: str = Field(...)
|
||||
n: int = Field(...)
|
||||
seed: int = Field(...)
|
||||
response_for: str = Field("url")
|
||||
response_format: str = Field("url")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
|
||||
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
@@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel):
|
||||
revised_prompt: str | None = Field(None)
|
||||
|
||||
|
||||
class UsageObject(BaseModel):
|
||||
cost_in_usd_ticks: int | None = Field(None)
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
data: list[ImageResponseObject] = Field(...)
|
||||
usage: UsageObject | None = Field(None)
|
||||
|
||||
|
||||
class VideoGenerationResponse(BaseModel):
|
||||
@@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel):
|
||||
status: str | None = Field(None)
|
||||
video: VideoResponseObject | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
usage: UsageObject | None = Field(None)
|
||||
|
||||
@@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel):
|
||||
keep_original_sound: str = Field(...)
|
||||
character_orientation: str = Field(...)
|
||||
mode: str = Field(..., description="'pro' or 'std'")
|
||||
model_name: str = Field(...)
|
||||
|
||||
@@ -27,6 +27,12 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
|
||||
def _extract_grok_price(response) -> float | None:
|
||||
if response.usage and response.usage.cost_in_usd_ticks is not None:
|
||||
return response.usage.cost_in_usd_ticks / 10_000_000_000
|
||||
return None
|
||||
|
||||
|
||||
class GrokImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode):
|
||||
category="api node/image/Grok",
|
||||
description="Generate images using Grok based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode):
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
|
||||
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||
{"type":"usd","usd": $rate * widgets.number_of_images}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
number_of_images: int,
|
||||
seed: int,
|
||||
resolution: str = "1K",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
response = await sync_op(
|
||||
@@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode):
|
||||
aspect_ratio=aspect_ratio,
|
||||
n=number_of_images,
|
||||
seed=seed,
|
||||
resolution=resolution.lower(),
|
||||
),
|
||||
response_model=ImageGenerationResponse,
|
||||
price_extractor=_extract_grok_price,
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
||||
@@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
category="api node/image/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||
),
|
||||
IO.Image.Input("image", display_name="images"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="The text prompt used to generate the image",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1K"]),
|
||||
IO.Combo.Input("resolution", options=["1K", "2K"]),
|
||||
IO.Int.Input(
|
||||
"number_of_images",
|
||||
default=1,
|
||||
@@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:3",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"9:19.5",
|
||||
"19.5:9",
|
||||
"9:20",
|
||||
"20:9",
|
||||
"1:2",
|
||||
"2:1",
|
||||
],
|
||||
optional=True,
|
||||
tooltip="Only allowed when multiple images are connected to the image input.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
|
||||
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
resolution: str,
|
||||
number_of_images: int,
|
||||
seed: int,
|
||||
aspect_ratio: str = "auto",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
if model == "grok-imagine-image-pro":
|
||||
if get_number_of_images(image) > 1:
|
||||
raise ValueError("The pro model supports only 1 input image.")
|
||||
elif get_number_of_images(image) > 3:
|
||||
raise ValueError("A maximum of 3 input images is supported.")
|
||||
if aspect_ratio != "auto" and get_number_of_images(image) == 1:
|
||||
raise ValueError(
|
||||
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
|
||||
data=ImageEditRequest(
|
||||
model=model,
|
||||
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
|
||||
images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
|
||||
prompt=prompt,
|
||||
resolution=resolution.lower(),
|
||||
n=number_of_images,
|
||||
seed=seed,
|
||||
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
|
||||
),
|
||||
response_model=ImageGenerationResponse,
|
||||
price_extractor=_extract_grok_price,
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
||||
@@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
category="api node/video/Grok",
|
||||
description="Generate video from a prompt or an image",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
|
||||
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
|
||||
expr="""
|
||||
(
|
||||
$base := 0.181 * widgets.duration;
|
||||
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
|
||||
$base := $rate * widgets.duration;
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
|
||||
)
|
||||
""",
|
||||
@@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
@@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
category="api node/video/Grok",
|
||||
description="Edit an existing video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
|
||||
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
|
||||
expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
@@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode):
|
||||
"but the character orientation matches the reference image (camera/other details via prompt).",
|
||||
),
|
||||
IO.Combo.Input("mode", options=["pro", "std"]),
|
||||
IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode):
|
||||
keep_original_sound: bool,
|
||||
character_orientation: str,
|
||||
mode: str,
|
||||
model: str = "kling-v2-6",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=2500)
|
||||
validate_image_dimensions(reference_image, min_width=340, min_height=340)
|
||||
@@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode):
|
||||
keep_original_sound="yes" if keep_original_sound else "no",
|
||||
character_orientation=character_orientation,
|
||||
mode=mode,
|
||||
model_name=model,
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
|
||||
@@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
|
||||
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
if tile is not None:
|
||||
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||
audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||
else:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.15.1"
|
||||
__version__ = "0.16.2"
|
||||
|
||||
14
execution.py
14
execution.py
@@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
# Unwraps values wrapped in __value__ key. This is used to pass
|
||||
# list widget value to execution, as by default list value is
|
||||
# reserved to represent the connection between nodes.
|
||||
if isinstance(val, dict) and "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
# Unwraps values wrapped in __value__ key or typed wrapper.
|
||||
# This is used to pass list widget values to execution,
|
||||
# as by default list value is reserved to represent the
|
||||
# connection between nodes.
|
||||
if isinstance(val, dict):
|
||||
if "__value__" in val:
|
||||
val = val["__value__"]
|
||||
inputs[x] = val
|
||||
|
||||
if input_type == "INT":
|
||||
val = int(val)
|
||||
|
||||
10
main.py
10
main.py
@@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
comfy_aimdo.control.init()
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
@@ -28,6 +23,11 @@ if __name__ == "__main__":
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
comfy_aimdo.control.init()
|
||||
|
||||
if os.name == "nt":
|
||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.15.1"
|
||||
version = "0.16.2"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.39.19
|
||||
comfyui-workflow-templates==0.9.5
|
||||
comfyui-workflow-templates==0.9.10
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.2.4
|
||||
comfy-aimdo>=0.2.7
|
||||
requests
|
||||
|
||||
#non essential dependencies:
|
||||
|
||||
Reference in New Issue
Block a user