Merge branch 'master' into fix/api-nodes/vidu-pricing

This commit is contained in:
Alexander Piskun
2026-02-17 07:28:13 +02:00
committed by GitHub
12 changed files with 146 additions and 61 deletions

2
.gitignore vendored
View File

@@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs
.vscode/
.idea/
venv/
venv*/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example

View File

@@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:

View File

@@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img blocks
@@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else:
mod = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

View File

@@ -142,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -231,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:

View File

@@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None,
transformer_options={},
) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
@@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:

View File

@@ -406,13 +406,16 @@ class ModelPatcher:
def memory_required(self, input_shape):
    """Return whatever ``self.model.memory_required`` reports for ``input_shape``.

    The shape is forwarded by keyword (``input_shape=...``) to the wrapped model.
    """
    wrapped_model = self.model
    return wrapped_model.memory_required(input_shape=input_shape)
def disable_model_cfg1_optimization(self):
    """Set the ``"disable_cfg1_optimization"`` entry of ``self.model_options`` to True."""
    options = self.model_options
    options["disable_cfg1_optimization"] = True
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
    """Store a custom sampler CFG function in ``self.model_options``.

    Args:
        sampler_cfg_function: Callable saved under the ``"sampler_cfg_function"``
            key. If its signature has exactly three parameters it is treated as
            the old calling convention and wrapped so that it is invoked as
            ``f(args["cond"], args["uncond"], args["cond_scale"])``. Any other
            callable is stored unchanged.
        disable_cfg1_optimization: When true, also turns on the
            ``"disable_cfg1_optimization"`` flag via
            ``disable_model_cfg1_optimization()``.
    """
    if len(inspect.signature(sampler_cfg_function).parameters) == 3:
        self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
    else:
        self.model_options["sampler_cfg_function"] = sampler_cfg_function
    if disable_cfg1_optimization:
        # Delegate to the helper so the flag is written in exactly one place.
        self.disable_model_cfg1_optimization()
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)

View File

@@ -21,7 +21,6 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
@@ -463,7 +462,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
@@ -475,8 +474,7 @@ class disable_weight_init:
weight = None
bias = None
offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

View File

@@ -1,57 +1,10 @@
import torch
import comfy.model_management
import numbers
import logging
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
logging.warning("Please update pytorch to use native RMSNorm")
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.bias = None
def forward(self, x):
return rms_norm(x, self.weight, self.eps)
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)

View File

@@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
slider = "slider"
class ControlAfterGenerate(str, Enum):
    """String-valued options for an input's ``control_after_generate`` setting.

    Each member's value equals its name, and members compare equal to the
    plain string because the enum mixes in ``str``.
    """

    # NOTE(review): the per-member meaning below is inferred from the names only;
    # confirm against the consumer of this value.
    fixed = "fixed"          # presumably: keep the value as-is
    increment = "increment"  # presumably: step the value up
    decrement = "decrement"  # presumably: step the value down
    randomize = "randomize"  # presumably: pick a new random value
class _ComfyType(ABC):
Type = Any
io_type: str = None
@@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
control_after_generate: bool=None,
control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
@@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True
@@ -2097,6 +2103,7 @@ __all__ = [
"UploadType",
"RemoteOptions",
"NumberDisplay",
"ControlAfterGenerate",
"comfytype",
"Custom",

99
comfy_extras/nodes_nag.py Normal file
View File

@@ -0,0 +1,99 @@
import torch
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class NAGuidance(io.ComfyNode):
    """Experimental node that patches a model with Normalized Attention Guidance (NAG).

    The node clones the incoming model and registers an ``attn1`` output patch
    that rewrites the attention output of the positive-conditioning batch entry
    using the negative-conditioning entry.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node: a model input, three float parameters and a model output."""
        node_inputs = [
            io.Model.Input("model", tooltip="The model to apply NAG to."),
            io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
            io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
            io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
            # Step-range gating is currently disabled:
            # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
            # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
        ]
        node_outputs = [
            io.Model.Output(tooltip="The patched model with NAG enabled."),
        ]
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            is_experimental=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        """Clone ``model``, attach the NAG attention-output patch and return the clone.

        Args:
            model: Model to patch; it is cloned, and only the clone is modified here.
            nag_scale: Extrapolation factor applied to the positive output.
            nag_alpha: Blend weight of the guided result against the positive output.
            nag_tau: Upper bound on the guided/positive L1-norm ratio.

        Returns:
            ``io.NodeOutput`` wrapping the patched clone.
        """
        patched = model.clone()

        # Step-range gating is currently disabled:
        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)

        def nag_attention_output_patch(out, extra_options):
            """Apply NAG to an attention output tensor, modifying it in place.

            Returns ``out`` untouched when ``extra_options`` has no
            ``"cond_or_uncond"`` entry or when that entry does not contain both
            1 and 0. When ``"img_slice"`` is present only that range along
            dim 1 is processed.
            """
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            if cond_or_uncond is None:
                return out
            # Both a negative (1) and a positive (0) entry are needed in the batch.
            if 1 not in cond_or_uncond or 0 not in cond_or_uncond:
                return out

            # sigma = extra_options.get("sigmas", None)
            # if sigma is not None and len(sigma) > 0:
            #     sigma = sigma[0].item()
            #     if sigma > sigma_start or sigma < sigma_end:
            #         return out

            img_slice = extra_options.get("img_slice", None)
            full_out = out
            if img_slice is not None:
                token_range = slice(img_slice[0], img_slice[1])
                out = full_out[:, token_range]  # only apply on img part

            # The batch is split into equally sized chunks, one per cond_or_uncond entry.
            chunk = out.shape[0] // len(cond_or_uncond)
            neg_idx = cond_or_uncond.index(1)
            pos_idx = cond_or_uncond.index(0)
            pos_rows = slice(chunk * pos_idx, chunk * (pos_idx + 1))
            neg_rows = slice(chunk * neg_idx, chunk * (neg_idx + 1))
            z_pos = out[pos_rows]
            z_neg = out[neg_rows]

            # Extrapolate away from the negative output.
            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)

            # Limit how far the guided L1 norm may exceed the positive one (cap at nag_tau).
            eps = 1e-6
            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
            ratio = norm_guided / norm_pos
            ratio_cap = torch.full_like(ratio, nag_tau)
            scale_factor = torch.minimum(ratio, ratio_cap) / ratio
            guided_normalized = guided * scale_factor

            # Blend the normalized guidance with the untouched positive output.
            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)

            if img_slice is None:
                out[pos_rows] = z_final
                return out

            # With an image slice, both the negative and positive rows receive the result.
            full_out[neg_rows, token_range] = z_final
            full_out[pos_rows, token_range] = z_final
            return full_out

        patched.set_model_attn1_output_patch(nag_attention_output_patch)
        patched.disable_model_cfg1_optimization()
        return io.NodeOutput(patched)
class NagExtension(ComfyExtension):
    """Extension that provides the ``NAGuidance`` node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes contributed by this extension."""
        node_classes = [NAGuidance]
        return node_classes
async def comfy_entrypoint() -> NagExtension:
    """Create and return a new ``NagExtension`` instance."""
    extension = NagExtension()
    return extension

View File

@@ -2437,6 +2437,7 @@ async def init_builtin_extra_nodes():
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
]
import_failed = []

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.14
comfyui-workflow-templates==0.8.38
comfyui-workflow-templates==0.8.42
comfyui-embedded-docs==0.4.1
torch
torchsde