Compare commits


18 Commits

Author SHA1 Message Date
comfyanonymous
88e6370527 Remove workaround for old pytorch. (#12480) 2026-02-15 20:43:53 -05:00
rattus
c0370044cd MPDynamic: force load flux img_in weight (Fixes flux1 canny+depth lora crash) (#12446)
* lora: add weight shape calculations.

This lets the loader know if a lora will change the shape of a weight
so it can take appropriate action.

* MPDynamic: force load flux img_in weight

This weight is a bit special, in that the lora changes its geometry.
This is rather unusual: it isn't handled by the existing size estimates and doesn't
work for either offloading or dynamic_vram.

Fix dynamic_vram as a special case. Ideally we could fully precalculate
these lora geometry changes at load time, but for now just get these
models working.
2026-02-15 20:30:09 -05:00
rattus
ecd2a19661 Fix lora Extraction in offload conditions (+ dynamic_vram mode) (#12479)
* lora_extract: Add a trange

If you bite off more than your GPU can chew, this kind of just hangs.
Give a rough indication of progress by counting the weights in a trange.

* lora_extract: Support on-the-fly patching

Use the on-the-fly approach from the regular model saving logic for
lora extraction too. Switch off force_cast_weights accordingly.

This gets extraction working in dynamic_vram while also supporting
extraction on GPU-offloaded models.
2026-02-15 20:28:51 -05:00
Alexander Piskun
2c1d06a4e3 feat(api-nodes): add Bria RMBG nodes (#12465)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-15 17:22:30 -08:00
Alexander Piskun
e2c71ceb00 feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes (#12428)
* feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes

* add image output to TencentModelTo3DUV node

* commented out two nodes

* added rate_limit check to other hunyuan3d nodes
2026-02-15 05:33:18 -08:00
Jedrzej Kosinski
596ed68691 Node Replacement API (#12014) 2026-02-15 02:12:30 -08:00
Alexander Piskun
ce4a1ab48d chore(api-nodes): remove "gpt-4o" model (#12467) 2026-02-15 01:31:59 -08:00
comfyanonymous
e1ede29d82 Remove unsafe pickle loading code that was used on pytorch older than 2.4 (#12473)
ComfyUI hasn't started on pytorch 2.4 since last month.
2026-02-14 22:53:52 -05:00
Christian Byrne
df1e5e8514 Update frontend package to 1.38.14 (#12469) 2026-02-14 11:01:10 -08:00
krigeta
dc9822b7df Add working Qwen 2512 ControlNet (Fun ControlNet) support (#12359) 2026-02-13 22:23:52 -05:00
comfyanonymous
712efb466b Add left padding to LTXAV text encoder. (#12456) 2026-02-13 21:56:54 -05:00
comfyanonymous
726af73867 Fix some custom nodes. (#12455) 2026-02-13 20:21:10 -05:00
comfyanonymous
831351a29e Support generating attention masks for left padded text encoders. (#12454) 2026-02-13 20:15:23 -05:00
comfyanonymous
e1add563f9 Use torch RMSNorm for flux models and refactor hunyuan video code. (#12432) 2026-02-13 15:35:13 -05:00
rattus
8902907d7a dynamic_vram: Training fixes (#12442) 2026-02-13 15:29:37 -05:00
comfyanonymous
e03fe8b591 Update command to install AMD stable linux pytorch. (#12437) 2026-02-12 23:29:12 -05:00
rattus
ae79e33345 llama: use a more efficient rope implementation (#12434)
Get rid of the cat and unary negation and instead in-place addcmul the
two halves of the rope. Precompute -sin once at the start of the model
rather than in every transformer block.

This is slightly faster on both GPU- and CPU-bound setups.
2026-02-12 19:56:42 -05:00
rattus
117e214354 ModelPatcherDynamic: force load non leaf weights (#12433)
The current behaviour of the default ModelPatcher is to .to() a model
only if it's fully loaded, which is how random non-leaf weights get
loaded in non-LowVRAM conditions.

This, however, means they never get loaded in dynamic_vram. In the
dynamic_vram case, force load them onto the GPU.
2026-02-12 19:51:50 -05:00
45 changed files with 1323 additions and 673 deletions

View File

@@ -227,7 +227,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:

app/node_replace_manager.py (new file, 105 lines)
View File

@@ -0,0 +1,105 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if is_link(input_value):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())
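
For orientation, here is a minimal sketch of how a replacement might be registered and applied with this manager. The node class names are hypothetical; NodeReplace and its mapping fields come from the comfy_api changes further down, and a real extension would normally register through the new ComfyAPI NodeReplacement proxy rather than constructing the manager itself.

# Sketch only: hypothetical node names; assumes a running ComfyUI environment.
from comfy_api.latest._io_public import NodeReplace
from app.node_replace_manager import NodeReplaceManager

manager = NodeReplaceManager()
manager.register(NodeReplace(
    new_node_id="LoadImageV2",                   # hypothetical replacement node
    old_node_id="LoadImageLegacy",               # hypothetical removed node
    input_mapping=[
        {"new_id": "image", "old_id": "image"},  # carry an input over by ID
        {"new_id": "mode", "set_value": "RGB"},  # or pin a new input to a value
    ],
    output_mapping=[{"old_idx": 0, "new_idx": 0}],
))

# apply_replacements only rewrites nodes whose class_type is missing from
# NODE_CLASS_MAPPINGS and whose replacement target actually exists; it swaps
# class_type, remaps inputs, and re-points downstream links whose output
# index changed.
prompt = {
    "3": {"class_type": "LoadImageLegacy",
          "inputs": {"image": "example.png"},
          "_meta": {"title": "Load Image"}},
}
manager.apply_replacements(prompt)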

View File

@@ -1,13 +0,0 @@
import pickle
load = pickle.load
class Empty:
pass
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
#TODO: safe unpickle
if module.startswith("pytorch_lightning"):
return Empty
return super().find_class(module, name)

View File

@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
class QwenFunControlNet(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
# Fun checkpoints are more sensitive to high strengths in the generic
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
# unchanged while >1 grows more gently.
original_strength = self.strength
self.strength = math.sqrt(max(self.strength, 0.0))
try:
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
finally:
self.strength = original_strength
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_qwen_fun(sd, model_options={}):
load_device = comfy.model_management.get_torch_device()
weight_dtype = comfy.utils.weight_dtype(sd)
unet_dtype = model_options.get("dtype", weight_dtype)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
in_features = sd["control_img_in.weight"].shape[1]
inner_dim = sd["control_img_in.weight"].shape[0]
block_weight = sd["control_blocks.0.attn.to_q.weight"]
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
control_in_features=in_features,
inner_dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
operations=operations,
device=comfy.model_management.unet_offload_device(),
dtype=unet_dtype,
)
model = controlnet_load_state_dict(model, sd)
latent_format = comfy.latent_formats.Wan21()
control = QwenFunControlNet(
model,
compression_ratio=1,
latent_format=latent_format,
# Fun checkpoints already expect their own 33-channel context handling.
# Enabling generic concat_mask injects an extra mask channel at apply-time
# and breaks the intended fallback packing path.
concat_mask=False,
load_device=load_device,
manual_cast_dtype=manual_cast_dtype,
extra_conds=[],
)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

View File

@@ -3,7 +3,6 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
ModulationOut,
)
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property

View File

@@ -4,8 +4,6 @@ from functools import lru_cache
import torch
from torch import nn
from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module):
"""
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,

View File

@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
# Fix import for some custom nodes, TODO: delete eventually.
RMSNorm = None
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)

View File

@@ -16,7 +16,6 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
)
@dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else:
self.txt_norm = None

View File

@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
img = torch.cat((img, txt), 1)
img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
img = img[:, : img_len]
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]

View File

@@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x)
def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):

View File

@@ -2,6 +2,196 @@ import torch
import math
from .model import QwenImageTransformer2DModel
from .model import QwenImageTransformerBlock
class QwenImageFunControlBlock(QwenImageTransformerBlock):
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
super().__init__(
dim=dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations,
)
self.has_before_proj = has_before_proj
if has_before_proj:
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
class QwenImageFunControlNetModel(torch.nn.Module):
def __init__(
self,
control_in_features=132,
inner_dim=3072,
num_attention_heads=24,
attention_head_dim=128,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.main_model_double = main_model_double
self.injection_layers = tuple(injection_layers)
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
# to the reference Gen2/VideoX implementation around strength=1.
self.hint_scale = 1.0
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
self.control_blocks = torch.nn.ModuleList([])
for i in range(num_control_blocks):
self.control_blocks.append(
QwenImageFunControlBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
has_before_proj=(i == 0),
dtype=dtype,
device=device,
operations=operations,
)
)
def _process_hint_tokens(self, hint):
if hint is None:
return None
if hint.ndim == 4:
hint = hint.unsqueeze(2)
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
# Default behavior (no inpaint input in stock Apply ControlNet) should use
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
expected_c = self.control_img_in.weight.shape[1] // 4
if hint.shape[1] == 16 and expected_c == 33:
zeros_mask = torch.zeros_like(hint[:, :1])
zeros_inpaint = torch.zeros_like(hint)
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
bs, c, t, h, w = hint.shape
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(
orig_shape[0],
orig_shape[1],
orig_shape[-3],
orig_shape[-2] // 2,
2,
orig_shape[-1] // 2,
2,
)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(
bs,
t * ((h + 1) // 2) * ((w + 1) // 2),
c * 4,
)
expected_in = self.control_img_in.weight.shape[1]
cur_in = hidden_states.shape[-1]
if cur_in < expected_in:
pad = torch.zeros(
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
hidden_states = torch.cat([hidden_states, pad], dim=-1)
elif cur_in > expected_in:
hidden_states = hidden_states[:, :, :expected_in]
return hidden_states
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
hint=None,
transformer_options={},
base_model=None,
**kwargs,
):
if base_model is None:
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
encoder_hidden_states_mask = attention_mask
# Keep attention mask disabled inside Fun control blocks to mirror
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
encoder_hidden_states_mask = None
hidden_states, img_ids, _ = base_model.process_img(x)
hint_tokens = self._process_hint_tokens(hint)
if hint_tokens is None:
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
if hint_tokens.shape[1] != hidden_states.shape[1]:
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
hint_tokens = hint_tokens[:, :max_tokens]
hidden_states = hidden_states[:, :max_tokens]
img_ids = img_ids[:, :max_tokens]
txt_start = round(
max(
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
)
)
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
hidden_states = base_model.img_in(hidden_states)
encoder_hidden_states = base_model.txt_norm(context)
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
base_model.time_text_embed(timesteps, hidden_states)
if guidance is None
else base_model.time_text_embed(timesteps, guidance, hidden_states)
)
c = self.control_img_in(hint_tokens)
for i, block in enumerate(self.control_blocks):
if i == 0:
c_in = block.before_proj(c) + hidden_states
all_c = []
else:
all_c = list(torch.unbind(c, dim=0))
c_in = all_c.pop(-1)
encoder_hidden_states, c_out = block(
hidden_states=c_in,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
c_skip = block.after_proj(c_out) * self.hint_scale
all_c += [c_skip, c_out]
c = torch.stack(all_c, dim=0)
hints = torch.unbind(c, dim=0)[:-1]
controlnet_block_samples = [None] * self.main_model_double
for local_idx, base_idx in enumerate(self.injection_layers):
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
controlnet_block_samples[base_idx] = hints[local_idx]
return {"input": controlnet_block_samples}
class QwenImageControlNetModel(QwenImageTransformer2DModel):

View File

@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
def calculate_shape(patches, weight, key, original_weights=None):
current_shape = weight.shape
for p in patches:
v = p[1]
offset = p[3]
# Offsets restore the old shape; lists force a diff without metadata
if offset is not None or isinstance(v, list):
continue
if isinstance(v, weight_adapter.WeightAdapterBase):
adapter_shape = v.calculate_shape(key)
if adapter_shape is not None:
current_shape = adapter_shape
continue
# Standard diff logic with padding
if len(v) == 2:
patch_type, patch_data = v[0], v[1]
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
current_shape = patch_data[0].shape
return current_shape
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]
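
A short sketch of how a loader can use calculate_shape to spot a shape-changing patch before deciding how to load a weight; this mirrors the ModelPatcherDynamic change later in this diff, and the helper name here is illustrative:

import comfy.lora

def lora_changes_shape(patches, weight, key):
    # calculate_shape returns the post-patch shape; if it differs from the
    # current geometry, on-the-fly patching can't be used for this key and
    # the weight has to be force loaded and patched up front.
    return tuple(comfy.lora.calculate_shape(patches, weight, key)) != tuple(weight.shape)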

View File

@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])

View File

@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def any_suffix_in(keys, prefix, main, suffix_list=[]):
for x in suffix_list:
if "{}{}{}".format(prefix, main, x) in keys:
return True
return False
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]

View File

@@ -679,18 +679,19 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False):
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
skip = True # skip random weights in non leaf modules
default = True # default random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher):
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
for x in loading:
@@ -1513,8 +1514,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return 0
return (False, 0)
if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
@@ -1528,7 +1531,13 @@ class ModelPatcherDynamic(ModelPatcher):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
return (False, comfy.memory_management.vram_aligned_size(geometry))
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1536,13 +1545,19 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n
set_dirty(m, dirty)
v_weight_size = 0
v_weight_size += setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "bias")
force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else:
for param in params:
@@ -1560,6 +1575,8 @@ class ModelPatcherDynamic(ModelPatcher):
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
@@ -1579,7 +1596,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
@@ -1600,6 +1617,13 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above

View File

@@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
pad_token = self.special_tokens.get("pad", -1)
if end_token is None:
cmp_token = self.special_tokens.get("pad", -1)
cmp_token = pad_token
else:
cmp_token = end_token
@@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
other_embeds = []
eos = False
index = 0
left_pad = False
for y in x:
if isinstance(y, numbers.Integral):
if eos:
token = int(y)
if index == 0 and token == pad_token:
left_pad = True
if eos or (left_pad and token == pad_token):
attention_mask.append(0)
else:
attention_mask.append(1)
token = int(y)
left_pad = False
tokens_temp += [token]
if not eos and token == cmp_token:
if not eos and token == cmp_token and not left_pad:
if end_token is None:
attention_mask[-1] = 0
eos = True
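
The effect on a left-padded prompt is easiest to see with a tiny example: leading pad tokens get a 0 in the generated attention mask and everything from the first real token onward gets a 1 (token IDs below are purely illustrative):

# Illustrative only: assume pad_token = 0 and a left-padded sequence.
tokens         = [0, 0, 0, 2, 414, 1057, 9]
attention_mask = [0, 0, 0, 1, 1,   1,    1]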

View File

@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith("_norm.scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)
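
These process_unet_state_dict additions rename RMSNorm parameters from the old .scale suffix to the .weight suffix used by torch's RMSNorm; in isolation the rename looks like this:

# Sketch of the key rename performed above.
k = "double_blocks.0.img_attn.norm.key_norm.scale"
if k.endswith(".scale"):
    k = "{}.weight".format(k[:-len(".scale")])
# k is now "double_blocks.0.img_attn.norm.key_norm.weight"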

View File

@@ -10,7 +10,6 @@ import comfy.utils
def sample_manual_loop_no_classes(
model,
ids=None,
paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
@@ -36,9 +35,6 @@ def sample_manual_loop_no_classes(
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
output_audio_codes = []
past_key_values = []
@@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
paddings = []
ids = [positive]
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):

View File

@@ -355,13 +355,6 @@ class RMSNorm(nn.Module):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
@@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
nsin = freqs_cis[2]
q_embed = (xq * cos)
q_split = q_embed.shape[-1] // 2
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
k_embed = (xk * cos)
k_split = k_embed.shape[-1] // 2
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
return q_embed.to(org_dtype), k_embed.to(org_dtype)
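
For reference, a toy check that the new split-half addcmul_ formulation matches the old rotate_half one; shapes and names are simplified, and the real code runs on (batch, heads, seq, dim) tensors:

import torch

def rope_rotate_half(x, cos, sin):
    # old path: cat plus unary negation
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return (x * cos) + (torch.cat((-x2, x1), dim=-1) * sin)

def rope_addcmul(x, cos, sin_half, nsin_half):
    # new path: multiply by cos, then in-place addcmul the two halves
    out = x * cos
    half = out.shape[-1] // 2
    out[..., :half].addcmul_(x[..., half:], nsin_half)  # += x2 * (-sin)
    out[..., half:].addcmul_(x[..., :half], sin_half)   # += x1 * sin
    return out

x = torch.randn(2, 8)
ang = torch.rand(2, 4)
cos, sin = torch.cos(ang).repeat(1, 2), torch.sin(ang).repeat(1, 2)
expected = rope_rotate_half(x, cos, sin)
actual = rope_addcmul(x, cos, sin[..., :4], -sin[..., 4:])
assert torch.allclose(expected, actual, atol=1e-6)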

View File

@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens = max(num_tokens, 64)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

View File

@@ -20,7 +20,7 @@
import torch
import math
import struct
import comfy.checkpoint_pickle
import comfy.memory_management
import safetensors.torch
import numpy as np
from PIL import Image
@@ -38,26 +38,26 @@ import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
if True: # ckpt/pt file whitelist for safe loading of old sd files
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
def scalar(*args, **kwargs):
from numpy.core.multiarray import scalar as sc
return sc(*args, **kwargs)
return None
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode
def encode(*args, **kwargs): # no longer necessary on newer torch
return None
encode.__module__ = "_codecs"
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
# Current as of safetensors 0.7.0
_TYPES = {
@@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@@ -675,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
}
for k in block_map:
@@ -701,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.norm_q.weight": "norm.query_norm.weight",
"attn.norm_k.weight": "norm.key_norm.weight",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}

View File

@@ -49,6 +49,12 @@ class WeightAdapterBase:
"""
raise NotImplementedError
def calculate_shape(
self,
key
):
return None
def calculate_weight(
self,
weight,

View File

@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else:
return None
def calculate_shape(
self,
key
):
reshape = self.weights[5]
return tuple(reshape) if reshape is not None else None
def calculate_weight(
self,
weight,

View File

@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
}

View File

@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
execution: Execution
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""

View File

@@ -1203,89 +1203,6 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLOR_CORRECT")
class ColorCorrect(ComfyTypeIO):
Type = dict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {
"temperature": 0,
"hue": 0,
"brightness": 0,
"contrast": 0,
"saturation": 0,
"gamma": 1.0
}
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLOR_BALANCE")
class ColorBalance(ComfyTypeIO):
Type = dict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {
"shadows_red": 0,
"shadows_green": 0,
"shadows_blue": 0,
"midtones_red": 0,
"midtones_green": 0,
"midtones_blue": 0,
"highlights_red": 0,
"highlights_green": 0,
"highlights_blue": 0
}
def as_dict(self):
return super().as_dict()
@comfytype(io_type="COLOR_CURVES")
class ColorCurves(ComfyTypeIO):
Type = dict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {
"rgb": [[0, 0], [1, 1]],
"red": [[0, 0], [1, 1]],
"green": [[0, 0], [1, 1]],
"blue": [[0, 0], [1, 1]]
}
def as_dict(self):
return super().as_dict()
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
Type = dict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, component: str=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
def as_dict(self):
d = super().as_dict()
if self.component:
d["component"] = self.component
return d
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -2113,6 +2030,68 @@ class _UIOutput(ABC):
...
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
__all__ = [
"FolderType",
"UploadType",
@@ -2204,8 +2183,5 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"ColorCorrect",
"ColorBalance",
"ColorCurves"
"NodeReplace",
]

View File

@@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
)
class BriaRemoveBackgroundRequest(BaseModel):
image: str = Field(...)
sync: bool = Field(False)
visual_input_content_moderation: bool = Field(
False, description="If true, returns 422 on input image moderation failure."
)
visual_output_content_moderation: bool = Field(
False, description="If true, returns 422 on visual output moderation failure."
)
seed: int = Field(...)
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
class BriaResult(BaseModel):
class BriaRemoveBackgroundResult(BaseModel):
image_url: str = Field(...)
class BriaRemoveBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveBackgroundResult | None = Field(None)
class BriaImageEditResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
class BriaResponse(BaseModel):
class BriaImageEditResponse(BaseModel):
status: str = Field(...)
result: BriaResult | None = Field(None)
result: BriaImageEditResult | None = Field(None)
class BriaRemoveVideoBackgroundRequest(BaseModel):
video: str = Field(...)
background_color: str = Field(default="transparent", description="Background color for the output video.")
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaRemoveVideoBackgroundResult(BaseModel):
video_url: str = Field(...)
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)

View File

@@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
class To3DUVFileInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...)
class To3DUVTaskRequest(BaseModel):
File: To3DUVFileInput = Field(...)
class TextureEditImageInfo(BaseModel):
Url: str = Field(...)
class TextureEditTaskRequest(BaseModel):
File3D: To3DUVFileInput = Field(...)
Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None)

View File

@@ -3,7 +3,11 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
BriaResponse,
BriaRemoveBackgroundRequest,
BriaRemoveBackgroundResponse,
BriaRemoveVideoBackgroundRequest,
BriaRemoveVideoBackgroundResponse,
BriaImageEditResponse,
BriaStatusResponse,
InputModerationSettings,
)
@@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
download_url_to_video_output,
poll_op,
sync_op,
upload_images_to_comfyapi,
upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_video_duration,
)
@@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"moderation",
options=[
IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
IO.Boolean.Input(
"prompt_content_moderation", default=False
),
IO.Boolean.Input(
"visual_input_moderation", default=False
),
IO.Boolean.Input(
"visual_output_moderation", default=True
),
IO.Boolean.Input("prompt_content_moderation", default=False),
IO.Boolean.Input("visual_input_moderation", default=False),
IO.Boolean.Input("visual_output_moderation", default=True),
],
),
IO.DynamicCombo.Option("false", []),
],
tooltip="Moderation settings",
),
@@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
mask: Input.Image | None = None,
) -> IO.NodeOutput:
if not prompt and not structured_prompt:
raise ValueError(
"One of prompt or structured_prompt is required to be non-empty."
)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
mask_url = None
if mask is not None:
mask_url = (
await upload_images_to_comfyapi(
cls,
convert_mask_to_image(mask),
max_images=1,
mime_type="image/png",
wait_label="Uploading mask",
)
)[0]
mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
response = await sync_op(
cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest(
instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None,
images=await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
wait_label="Uploading image",
),
images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale,
seed=seed,
model_version=model,
steps_num=steps,
prompt_content_moderation=moderation.get(
"prompt_content_moderation", False
),
visual_input_content_moderation=moderation.get(
"visual_input_moderation", False
),
visual_output_content_moderation=moderation.get(
"visual_output_moderation", False
),
prompt_content_moderation=moderation.get("prompt_content_moderation", False),
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
),
response_model=BriaStatusResponse,
)
@@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaResponse,
response_model=BriaImageEditResponse,
)
return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url),
@@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
)
class BriaRemoveImageBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaRemoveImageBackground",
display_name="Bria Remove Image Background",
category="api node/image/Bria",
description="Remove the background from an image using Bria RMBG 2.0.",
inputs=[
IO.Image.Input("image"),
IO.DynamicCombo.Input(
"moderation",
options=[
IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
IO.Boolean.Input("visual_input_moderation", default=False),
IO.Boolean.Input("visual_output_moderation", default=True),
],
),
],
tooltip="Moderation settings",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.018}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
moderation: dict,
seed: int,
) -> IO.NodeOutput:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
data=BriaRemoveBackgroundRequest(
image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
sync=False,
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
class BriaRemoveVideoBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaRemoveVideoBackground",
display_name="Bria Remove Video Background",
category="api node/video/Bria",
description="Remove the background from a video using Bria. ",
inputs=[
IO.Video.Input("video"),
IO.Combo.Input(
"background_color",
options=[
"Black",
"White",
"Gray",
"Red",
"Green",
"Blue",
"Yellow",
"Cyan",
"Magenta",
"Orange",
],
tooltip="Background color for the output video.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
background_color: str,
seed: int,
) -> IO.NodeOutput:
validate_video_duration(video, max_duration=60.0)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
data=BriaRemoveVideoBackgroundRequest(
video=await upload_video_to_comfyapi(cls, video),
background_color=background_color,
output_container_and_codec="mp4_h264",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BriaImageEditNode,
BriaRemoveImageBackground,
BriaRemoveVideoBackground,
]

View File

@@ -1,31 +1,48 @@
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage,
InputGenerateType,
ResultFile3D,
TextureEditTaskRequest,
To3DProTaskCreateResponse,
To3DProTaskQueryRequest,
To3DProTaskRequest,
To3DProTaskResultResponse,
To3DUVFileInput,
To3DUVTaskRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
upload_3d_model_to_comfyapi,
upload_image_to_comfyapi,
validate_image_dimensions,
validate_string,
)
def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
def _is_tencent_rate_limited(status: int, body: object) -> bool:
return (
status == 400
and isinstance(body, dict)
and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
)
def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None:
for i in response_objs:
if i.Type.lower() == file_type.lower():
return i
if raise_if_not_found:
raise ValueError(f"'{file_type}' file type is not found in the response.")
return None
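A rough illustration of the rate-limit check (the payload shape is assumed from the condition above):
body = {"Response": {"Error": {"Code": "RequestLimitExceeded", "Message": "quota exceeded"}}}
assert _is_tencent_rate_limited(400, body)
assert not _is_tencent_rate_limited(400, {"Response": {"Error": {"Code": "InvalidParameter"}}})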
@@ -35,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model (Pro)",
display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
@@ -120,6 +137,7 @@ class TencentTextToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -131,11 +149,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
)
@@ -145,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model (Pro)",
display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
@@ -268,6 +289,7 @@ class TencentImageToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -279,11 +301,257 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
)
class TencentModelTo3DUVNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TencentModelTo3DUVNode",
display_name="Hunyuan3D: Model to UV",
category="api node/3d/Tencent",
description="Perform UV unfolding on a 3D model to generate UV texture. "
"Input model must have less than 30000 faces.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
tooltip="Input 3D model (GLB, OBJ, or FBX)",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
)
SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format not in cls.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported file format: '{file_format}'. "
f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(
Type=file_format.upper(),
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
)
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
)
class Tencent3DTextureEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Tencent3DTextureEditNode",
display_name="Hunyuan3D: 3D Texture Edit",
category="api node/3d/Tencent",
description="After inputting the 3D model, perform 3D model texture redrawing.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DFBX, IO.File3DAny],
tooltip="3D model in FBX format. Model should have less than 100000 faces.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.6}""",
),
)
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
prompt: str,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format != "fbx":
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
response_model=To3DProTaskCreateResponse,
data=TextureEditTaskRequest(
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
Prompt=prompt,
EnablePBR=True,
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
class Tencent3DPartNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Tencent3DPartNode",
display_name="Hunyuan3D: 3D Part",
category="api node/3d/Tencent",
description="Automatically perform component identification and generation based on the model structure.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DFBX, IO.File3DAny],
tooltip="3D model in FBX format. Model should have less than 30000 faces.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
)
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format != "fbx":
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
@@ -293,6 +561,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
return [
TencentTextToModelNode,
TencentImageToModelNode,
# TencentModelTo3DUVNode,
# Tencent3DTextureEditNode,
Tencent3DPartNode,
]

View File

@@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
@@ -649,11 +648,6 @@ class OpenAIChatNode(IO.ComfyNode):
"usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4o") ? {
"type": "list_usd",
"usd": [0.0025, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd",
"usd": [0.0001, 0.0004],

View File

@@ -33,6 +33,7 @@ from .download_helpers import (
download_url_to_video_output,
)
from .upload_helpers import (
upload_3d_model_to_comfyapi,
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_image_to_comfyapi,
@@ -62,6 +63,7 @@ __all__ = [
"sync_op",
"sync_op_raw",
# Upload helpers
"upload_3d_model_to_comfyapi",
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_image_to_comfyapi",

View File

@@ -57,7 +57,7 @@ def tensor_to_bytesio(
image: torch.Tensor,
*,
total_pixels: int | None = 2048 * 2048,
mime_type: str = "image/png",
mime_type: str | None = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.

View File

@@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
_3D_MIME_TYPES = {
"glb": "model/gltf-binary",
"obj": "model/obj",
"fbx": "application/octet-stream",
}
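# Formats not listed above fall back to "application/octet-stream" via the .get() below.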
async def upload_3d_model_to_comfyapi(
cls: type[IO.ComfyNode],
model_3d: Types.File3D,
file_format: str,
) -> str:
"""Uploads a 3D model file to ComfyUI API and returns its download URL."""
return await upload_file_to_comfyapi(
cls,
model_3d.get_data(),
f"{uuid.uuid4()}.{file_format}",
_3D_MIME_TYPES.get(file_format, "application/octet-stream"),
)
async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,

View File

@@ -1,78 +0,0 @@
from typing_extensions import override
import torch
from comfy_api.latest import ComfyExtension, io, ui
def _smoothstep(edge0: float, edge1: float, x: torch.Tensor) -> torch.Tensor:
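# Classic smoothstep: 0 at or below edge0, 1 at or above edge1, smooth cubic in between.
# For example, with edges (0.0, 0.5) and x = 0.25, t = 0.5 and the result is 0.5.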
t = torch.clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0)
return t * t * (3.0 - 2.0 * t)
class ColorBalanceNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ColorBalance",
display_name="Color Balance",
category="image/adjustment",
inputs=[
io.Image.Input("image"),
io.ColorBalance.Input("settings"),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
shadows_red = settings.get("shadows_red", 0)
shadows_green = settings.get("shadows_green", 0)
shadows_blue = settings.get("shadows_blue", 0)
midtones_red = settings.get("midtones_red", 0)
midtones_green = settings.get("midtones_green", 0)
midtones_blue = settings.get("midtones_blue", 0)
highlights_red = settings.get("highlights_red", 0)
highlights_green = settings.get("highlights_green", 0)
highlights_blue = settings.get("highlights_blue", 0)
result = image.clone().float()
# Compute per-pixel luminance
luminance = (
0.2126 * result[..., 0]
+ 0.7152 * result[..., 1]
+ 0.0722 * result[..., 2]
)
# Compute tonal range weights
shadow_weight = 1.0 - _smoothstep(0.0, 0.5, luminance)
highlight_weight = _smoothstep(0.5, 1.0, luminance)
midtone_weight = 1.0 - shadow_weight - highlight_weight
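# The three weights sum to 1.0 at every pixel, so the per-range offsets blend smoothly across tones.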
# Apply offsets per channel
for ch, (s, m, h) in enumerate([
(shadows_red, midtones_red, highlights_red),
(shadows_green, midtones_green, highlights_green),
(shadows_blue, midtones_blue, highlights_blue),
]):
offset = (
shadow_weight * (s / 100.0)
+ midtone_weight * (m / 100.0)
+ highlight_weight * (h / 100.0)
)
result[..., ch] = result[..., ch] + offset
result = torch.clamp(result, 0, 1)
return io.NodeOutput(result, ui=ui.PreviewImage(result))
class ColorBalanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ColorBalanceNode]
async def comfy_entrypoint() -> ColorBalanceExtension:
return ColorBalanceExtension()

View File

@@ -1,88 +0,0 @@
from typing_extensions import override
import torch
import numpy as np
from comfy_api.latest import ComfyExtension, io, ui
class ColorCorrectNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ColorCorrect",
display_name="Color Correct",
category="image/adjustment",
inputs=[
io.Image.Input("image"),
io.ColorCorrect.Input("settings"),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
temperature = settings.get("temperature", 0)
hue = settings.get("hue", 0)
brightness = settings.get("brightness", 0)
contrast = settings.get("contrast", 0)
saturation = settings.get("saturation", 0)
gamma = settings.get("gamma", 1.0)
result = image.clone()
# Brightness: scale RGB values
if brightness != 0:
factor = 1.0 + brightness / 100.0
result = result * factor
# Contrast: adjust around midpoint
if contrast != 0:
factor = 1.0 + contrast / 100.0
mean = result[..., :3].mean()
result[..., :3] = (result[..., :3] - mean) * factor + mean
# Temperature: shift warm (red+) / cool (blue+)
if temperature != 0:
temp_factor = temperature / 100.0
result[..., 0] = result[..., 0] + temp_factor * 0.1 # Red
result[..., 2] = result[..., 2] - temp_factor * 0.1 # Blue
# Gamma correction
if gamma != 1.0:
result[..., :3] = torch.pow(torch.clamp(result[..., :3], 0, 1), 1.0 / gamma)
# Saturation: convert to HSV-like space
if saturation != 0:
factor = 1.0 + saturation / 100.0
gray = result[..., :3].mean(dim=-1, keepdim=True)
result[..., :3] = gray + (result[..., :3] - gray) * factor
# Hue rotation: rotate in RGB space using rotation matrix
if hue != 0:
angle = np.radians(hue)
cos_a = np.cos(angle)
sin_a = np.sin(angle)
# Rodrigues' rotation formula around (1,1,1)/sqrt(3) axis
k = 1.0 / 3.0
rotation = torch.tensor([
[cos_a + k * (1 - cos_a), k * (1 - cos_a) - sin_a / np.sqrt(3), k * (1 - cos_a) + sin_a / np.sqrt(3)],
[k * (1 - cos_a) + sin_a / np.sqrt(3), cos_a + k * (1 - cos_a), k * (1 - cos_a) - sin_a / np.sqrt(3)],
[k * (1 - cos_a) - sin_a / np.sqrt(3), k * (1 - cos_a) + sin_a / np.sqrt(3), cos_a + k * (1 - cos_a)]
], dtype=result.dtype, device=result.device)
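# Rotating about the gray diagonal shifts hue while leaving neutral pixels (r == g == b) unchanged.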
rgb = result[..., :3]
result[..., :3] = torch.matmul(rgb, rotation.T)
result = torch.clamp(result, 0, 1)
return io.NodeOutput(result, ui=ui.PreviewImage(result))
class ColorCorrectExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ColorCorrectNode]
async def comfy_entrypoint() -> ColorCorrectExtension:
return ColorCorrectExtension()

View File

@@ -1,137 +0,0 @@
from typing_extensions import override
import torch
import numpy as np
from comfy_api.latest import ComfyExtension, io, ui
def _monotone_cubic_hermite(xs, ys, x_query):
"""Evaluate monotone cubic Hermite interpolation at x_query points."""
n = len(xs)
if n == 0:
return np.zeros_like(x_query)
if n == 1:
return np.full_like(x_query, ys[0])
# Compute slopes
deltas = np.diff(ys) / np.maximum(np.diff(xs), 1e-10)
# Compute tangents (Fritsch-Carlson)
slopes = np.zeros(n)
slopes[0] = deltas[0]
slopes[-1] = deltas[-1]
for i in range(1, n - 1):
if deltas[i - 1] * deltas[i] <= 0:
slopes[i] = 0
else:
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
# Enforce monotonicity
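# (Fritsch-Carlson limiter: if alpha^2 + beta^2 > 9, rescale both tangents so the segment stays monotone.)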
for i in range(n - 1):
if deltas[i] == 0:
slopes[i] = 0
slopes[i + 1] = 0
else:
alpha = slopes[i] / deltas[i]
beta = slopes[i + 1] / deltas[i]
s = alpha ** 2 + beta ** 2
if s > 9:
t = 3 / np.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
# Evaluate
result = np.zeros_like(x_query, dtype=np.float64)
indices = np.searchsorted(xs, x_query, side='right') - 1
indices = np.clip(indices, 0, n - 2)
for i in range(n - 1):
mask = indices == i
if not np.any(mask):
continue
dx = xs[i + 1] - xs[i]
if dx == 0:
result[mask] = ys[i]
continue
t = (x_query[mask] - xs[i]) / dx
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
result[mask] = h00 * ys[i] + h10 * dx * slopes[i] + h01 * ys[i + 1] + h11 * dx * slopes[i + 1]
# Clamp edges
result[x_query <= xs[0]] = ys[0]
result[x_query >= xs[-1]] = ys[-1]
return result
def _build_lut(points):
"""Build a 256-entry LUT from curve control points in [0,1] space."""
if not points or len(points) < 2:
return np.arange(256, dtype=np.float64) / 255.0
pts = sorted(points, key=lambda p: p[0])
xs = np.array([p[0] for p in pts], dtype=np.float64)
ys = np.array([p[1] for p in pts], dtype=np.float64)
x_query = np.linspace(0, 1, 256)
lut = _monotone_cubic_hermite(xs, ys, x_query)
return np.clip(lut, 0, 1)
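# Sanity check: with only the default endpoints, _build_lut([[0, 0], [1, 1]]) yields the identity ramp np.linspace(0, 1, 256).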
class ColorCurvesNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ColorCurves",
display_name="Color Curves",
category="image/adjustment",
inputs=[
io.Image.Input("image"),
io.ColorCurves.Input("settings"),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
rgb_pts = settings.get("rgb", [[0, 0], [1, 1]])
red_pts = settings.get("red", [[0, 0], [1, 1]])
green_pts = settings.get("green", [[0, 0], [1, 1]])
blue_pts = settings.get("blue", [[0, 0], [1, 1]])
rgb_lut = _build_lut(rgb_pts)
red_lut = _build_lut(red_pts)
green_lut = _build_lut(green_pts)
blue_lut = _build_lut(blue_pts)
# Convert to numpy for LUT application
img_np = image.cpu().numpy().copy()
# Apply per-channel curves then RGB master curve
for ch, ch_lut in enumerate([red_lut, green_lut, blue_lut]):
# Per-channel curve
indices = np.clip(img_np[..., ch] * 255, 0, 255).astype(np.int32)
img_np[..., ch] = ch_lut[indices]
# RGB master curve
indices = np.clip(img_np[..., ch] * 255, 0, 255).astype(np.int32)
img_np[..., ch] = rgb_lut[indices]
result = torch.from_numpy(np.clip(img_np, 0, 1)).to(image.device, dtype=image.dtype)
return io.NodeOutput(result, ui=ui.PreviewImage(result))
class ColorCurvesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ColorCurvesNode]
async def comfy_entrypoint() -> ColorCurvesExtension:
return ColorCurvesExtension()

View File

@@ -23,9 +23,8 @@ class ImageCrop(IO.ComfyNode):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
display_name="Image Crop (Deprecated)",
display_name="Image Crop",
category="image/transform",
is_deprecated=True,
inputs=[
IO.Image.Input("image"),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
@@ -48,57 +47,6 @@ class ImageCrop(IO.ComfyNode):
crop = execute # TODO: remove
class ImageCropV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageCropV2",
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, image, crop_region) -> IO.NodeOutput:
x = crop_region.get("x", 0)
y = crop_region.get("y", 0)
width = crop_region.get("width", 512)
height = crop_region.get("height", 512)
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return IO.NodeOutput(img)
class BoundingBox(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PrimitiveBoundingBox",
display_name="Bounding Box",
category="utils/primitive",
inputs=[
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
],
outputs=[IO.BoundingBox.Output()],
)
@classmethod
def execute(cls, x, y, width, height) -> IO.NodeOutput:
return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
class RepeatImageBatch(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -684,8 +632,6 @@ class ImagesExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCrop,
ImageCropV2,
BoundingBox,
RepeatImageBatch,
ImageFromBatch,
ImageAddNoise,

View File

@@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from tqdm.auto import trange
CLAMP_QUANTILE = 0.99
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
comfy.model_management.load_models_gpu([model_diff])
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
sd_keys = list(sd.keys())
for index in trange(len(sd_keys), unit="weight"):
k = sd_keys[index]
op_keys = sd_keys[index].rsplit('.', 1)
if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
continue
op = comfy.utils.get_attr(model_diff.model, op_keys[0])
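# If the op casts weights at compute time and nothing was patched in place (offloaded / dynamic_vram), materialize the patched weight on the fly instead of reading it from the state dict.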
if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
else:
weight_diff = sd[k]
if op_keys[1] == "weight":
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
elif bias_diff and op_keys[1] == "bias":
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
return output_sd
class LoraSave(io.ComfyNode):

View File

@@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@@ -0,0 +1,103 @@
from comfy_api.latest import ComfyExtension, io, ComfyAPI
api = ComfyAPI()
async def register_replacements():
"""Register all built-in node replacements."""
await register_replacements_longeredge()
await register_replacements_batchimages()
await register_replacements_upscaleimage()
await register_replacements_controlnet()
await register_replacements_load3d()
await register_replacements_preview3d()
await register_replacements_svdimg2vid()
await register_replacements_conditioningavg()
async def register_replacements_longeredge():
# No dynamic inputs here
await api.node_replacement.register(io.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
input_mapping=[
{"new_id": "image", "old_id": "images"},
{"new_id": "largest_size", "old_id": "longer_edge"},
{"new_id": "upscale_method", "set_value": "lanczos"},
],
# exists only to exercise the frontend output_mapping code; it has no practical effect here
output_mapping=[{"new_idx": 0, "old_idx": 0}],
))
async def register_replacements_batchimages():
# BatchImages node uses Autogrow
await api.node_replacement.register(io.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
{"new_id": "images.image0", "old_id": "image1"},
{"new_id": "images.image1", "old_id": "image2"},
],
))
async def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
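# Dotted ids such as "resize_type.multiplier" target inputs nested under the selected combo option, matching the "images.image0" convention used for Autogrow inputs above.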
await api.node_replacement.register(io.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
input_mapping=[
{"new_id": "input", "old_id": "image"},
{"new_id": "resize_type", "set_value": "scale by multiplier"},
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
{"new_id": "scale_method", "old_id": "upscale_method"},
],
))
async def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
await api.node_replacement.register(io.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
],
))
async def register_replacements_load3d():
# Load3DAnimation merged into Load3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
async def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
async def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
await api.node_replacement.register(io.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
async def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
await api.node_replacement.register(io.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))
class NodeReplacementsExtension(ComfyExtension):
async def on_load(self) -> None:
await register_replacements()
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return []
async def comfy_entrypoint() -> NodeReplacementsExtension:
return NodeReplacementsExtension()

View File

@@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input(
"offloading",
default=False,
tooltip="Depth level for gradient checkpointing.",
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
),
io.Combo.Input(
"existing_lora",
@@ -1124,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode

View File

@@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
await extension.on_load()
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@@ -2435,9 +2436,7 @@ async def init_builtin_extra_nodes():
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
"nodes_color_correct.py",
"nodes_color_balance.py",
"nodes_color_curves.py"
"nodes_replacements.py",
]
import_failed = []

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.38.13
comfyui-frontend-package==1.38.14
comfyui-workflow-templates==0.8.38
comfyui-embedded-docs==0.4.1
torch

View File

@@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@@ -204,6 +205,7 @@ class PromptServer():
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
@@ -887,6 +889,8 @@ class PromptServer():
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
self.node_replace_manager.apply_replacements(prompt)
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
@@ -995,6 +999,7 @@ class PromptServer():
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.