Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data
2025-08-23 17:46:43 +09:00
19 changed files with 667 additions and 29 deletions

View File

@@ -19,6 +19,7 @@ import torch
from torch import nn
import comfy.model_management
import comfy.patcher_extension
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
@@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
def forward(
def forward(self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
controlnet_scale, lyrics_strength, **kwargs)
def _forward(
self,
x,
timestep,

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.patcher_extension
import comfy.ldm.common_dit
def modulate(x, shift, scale):
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, transformer_options, **kwargs)
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
@@ -253,6 +254,13 @@ class Chroma(nn.Module):
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

View File

@@ -27,6 +27,8 @@ from torchvision import transforms
from enum import Enum
import logging
import comfy.patcher_extension
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x,
timesteps,
context,
attention_mask,
fps,
image_size,
padding_mask,
scalar_feature,
data_type,
latent_condition,
latent_condition_sigma,
condition_video_augment_sigma,
**kwargs)
def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:

View File

@@ -11,6 +11,7 @@ import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
@@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
def forward(
def forward(self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,

View File

@@ -6,6 +6,7 @@ import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension
from .layers import (
DoubleStreamBlock,
@@ -214,6 +215,13 @@ class Flux(nn.Module):
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size

View File

@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.patcher_extension
import comfy.ldm.common_dit
@@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
def forward(self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
image_cond=None,
control = None,
transformer_options = {},
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
def _forward(
self,
x: torch.Tensor,
t: torch.Tensor,

View File

@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
SingleStreamBlock,
timestep_embedding,
)
import comfy.patcher_extension
class Hunyuan3Dv2(nn.Module):
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context

View File

@@ -1,6 +1,7 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
import comfy.patcher_extension
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
@@ -348,6 +349,13 @@ class HunyuanVideo(nn.Module):
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@@ -1,5 +1,6 @@
import torch
from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)

View File

@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.patcher_extension
def modulate(x, scale):
@@ -590,8 +591,15 @@ class NextDiT(nn.Module):
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
# def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask

View File

@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
def modulate(x, shift, scale):
if shift is None:
shift = torch.zeros_like(scale)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))
#################################################################################
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = x + out1
x = x + out2
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
)
return self.post_attention(attn, *intermediates)
def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
return x
def block_mixing(*args, use_checkpoint=True, **kwargs):
if use_checkpoint:

View File

@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -214,9 +215,9 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
def _modulate(self, x, mod_params):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
self,
@@ -248,11 +249,11 @@ class QwenImageTransformerBlock(nn.Module):
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states
@@ -275,7 +276,7 @@ class LastLayer(nn.Module):
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
return x
@@ -355,7 +356,14 @@ class QwenImageTransformer2DModel(nn.Module):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def forward(
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
def _forward(
self,
x,
timesteps,

View File

@@ -11,6 +11,7 @@ from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
def sinusoidal_embedding_1d(dim, position):
@@ -148,8 +149,8 @@ WAN_CROSSATTENTION_CLASSES = {
def repeat_e(e, x):
repeats = 1
if e.shape[1] > 1:
repeats = x.shape[1] // e.shape[1]
if e.size(1) > 1:
repeats = x.size(1) // e.size(1)
if repeats == 1:
return e
return torch.repeat_interleave(e, repeats, dim=1)
@@ -219,15 +220,15 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs)
x = x + y * repeat_e(e[2], x)
x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
x = x + y * repeat_e(e[5], x)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -342,7 +343,7 @@ class Head(nn.Module):
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
return x
@@ -573,6 +574,13 @@ class WanModel(torch.nn.Module):
return x
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

View File

@@ -50,6 +50,7 @@ class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
PREDICT_NOISE = "predict_noise"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"

View File

@@ -17,6 +17,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.utils
import scipy.stats
import numpy
@@ -61,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert (mask.shape[1:] == x_in.shape[2:])
# assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
mask = mask * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
else:
mask = torch.ones_like(input_x)
mult = mask * strength
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
if len(mask.shape) == len(dims):
mask = mask.unsqueeze(0)
if mask.shape[1:] != dims:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
if mask.ndim < 4:
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
else:
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
@@ -953,7 +957,14 @@ class CFGGuider:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs)
return self.outer_predict_noise(*args, **kwargs)
def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.predict_noise,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
).execute(x, timestep, model_options, seed)
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)