Compare commits

..

2 Commits

Author SHA1 Message Date
Jedrzej Kosinski
5907e21dbc Merge branch 'master' into jk/requirements-files 2026-02-11 23:00:16 -08:00
Jedrzej Kosinski
2058a3a69c Added requirements files for different hardware vendors 2026-01-30 01:45:57 -08:00
27 changed files with 307 additions and 703 deletions

View File

@@ -227,7 +227,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:

View File

@@ -297,30 +297,6 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
class QwenFunControlNet(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
# Fun checkpoints are more sensitive to high strengths in the generic
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
# unchanged while >1 grows more gently.
original_strength = self.strength
self.strength = math.sqrt(max(self.strength, 0.0))
try:
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
finally:
self.strength = original_strength
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -584,7 +560,6 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -630,53 +605,6 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_qwen_fun(sd, model_options={}):
load_device = comfy.model_management.get_torch_device()
weight_dtype = comfy.utils.weight_dtype(sd)
unet_dtype = model_options.get("dtype", weight_dtype)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
in_features = sd["control_img_in.weight"].shape[1]
inner_dim = sd["control_img_in.weight"].shape[0]
block_weight = sd["control_blocks.0.attn.to_q.weight"]
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
control_in_features=in_features,
inner_dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
operations=operations,
device=comfy.model_management.unet_offload_device(),
dtype=unet_dtype,
)
model = controlnet_load_state_dict(model, sd)
latent_format = comfy.latent_formats.Wan21()
control = QwenFunControlNet(
model,
compression_ratio=1,
latent_format=latent_format,
# Fun checkpoints already expect their own 33-channel context handling.
# Enabling generic concat_mask injects an extra mask channel at apply-time
# and breaks the intended fallback packing path.
concat_mask=False,
load_device=load_device,
manual_cast_dtype=manual_cast_dtype,
extra_conds=[],
)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -754,8 +682,6 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

View File

@@ -3,6 +3,7 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
ModulationOut,
)
@@ -28,7 +29,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property

View File

@@ -4,6 +4,8 @@ from functools import lru_cache
import torch
from torch import nn
from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module):
"""
@@ -143,7 +145,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.mlp_ratio = mlp_ratio
@@ -176,7 +178,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -188,7 +190,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,

View File

@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
# Fix import for some custom nodes, TODO: delete eventually.
RMSNorm = None
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,12 +87,20 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -161,7 +169,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -189,6 +197,8 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -214,17 +224,32 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)

View File

@@ -16,6 +16,7 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
)
@dataclass
@@ -80,7 +81,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
else:
self.txt_norm = None

View File

@@ -241,6 +241,7 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -377,14 +378,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, :txt.shape[1]] = txt_mask
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
@@ -412,7 +413,7 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
img = torch.cat((txt, img), 1)
img = torch.cat((img, txt), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
@@ -434,9 +435,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
img[:, : img_len] += add
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]

View File

@@ -2,196 +2,6 @@ import torch
import math
from .model import QwenImageTransformer2DModel
from .model import QwenImageTransformerBlock
class QwenImageFunControlBlock(QwenImageTransformerBlock):
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
super().__init__(
dim=dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations,
)
self.has_before_proj = has_before_proj
if has_before_proj:
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
class QwenImageFunControlNetModel(torch.nn.Module):
def __init__(
self,
control_in_features=132,
inner_dim=3072,
num_attention_heads=24,
attention_head_dim=128,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.main_model_double = main_model_double
self.injection_layers = tuple(injection_layers)
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
# to the reference Gen2/VideoX implementation around strength=1.
self.hint_scale = 1.0
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
self.control_blocks = torch.nn.ModuleList([])
for i in range(num_control_blocks):
self.control_blocks.append(
QwenImageFunControlBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
has_before_proj=(i == 0),
dtype=dtype,
device=device,
operations=operations,
)
)
def _process_hint_tokens(self, hint):
if hint is None:
return None
if hint.ndim == 4:
hint = hint.unsqueeze(2)
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
# Default behavior (no inpaint input in stock Apply ControlNet) should use
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
expected_c = self.control_img_in.weight.shape[1] // 4
if hint.shape[1] == 16 and expected_c == 33:
zeros_mask = torch.zeros_like(hint[:, :1])
zeros_inpaint = torch.zeros_like(hint)
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
bs, c, t, h, w = hint.shape
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(
orig_shape[0],
orig_shape[1],
orig_shape[-3],
orig_shape[-2] // 2,
2,
orig_shape[-1] // 2,
2,
)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(
bs,
t * ((h + 1) // 2) * ((w + 1) // 2),
c * 4,
)
expected_in = self.control_img_in.weight.shape[1]
cur_in = hidden_states.shape[-1]
if cur_in < expected_in:
pad = torch.zeros(
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
hidden_states = torch.cat([hidden_states, pad], dim=-1)
elif cur_in > expected_in:
hidden_states = hidden_states[:, :, :expected_in]
return hidden_states
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
hint=None,
transformer_options={},
base_model=None,
**kwargs,
):
if base_model is None:
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
encoder_hidden_states_mask = attention_mask
# Keep attention mask disabled inside Fun control blocks to mirror
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
encoder_hidden_states_mask = None
hidden_states, img_ids, _ = base_model.process_img(x)
hint_tokens = self._process_hint_tokens(hint)
if hint_tokens is None:
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
if hint_tokens.shape[1] != hidden_states.shape[1]:
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
hint_tokens = hint_tokens[:, :max_tokens]
hidden_states = hidden_states[:, :max_tokens]
img_ids = img_ids[:, :max_tokens]
txt_start = round(
max(
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
)
)
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
hidden_states = base_model.img_in(hidden_states)
encoder_hidden_states = base_model.txt_norm(context)
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
base_model.time_text_embed(timesteps, hidden_states)
if guidance is None
else base_model.time_text_embed(timesteps, guidance, hidden_states)
)
c = self.control_img_in(hint_tokens)
for i, block in enumerate(self.control_blocks):
if i == 0:
c_in = block.before_proj(c) + hidden_states
all_c = []
else:
all_c = list(torch.unbind(c, dim=0))
c_in = all_c.pop(-1)
encoder_hidden_states, c_out = block(
hidden_states=c_in,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
c_skip = block.after_proj(c_out) * self.hint_scale
all_c += [c_skip, c_out]
c = torch.stack(all_c, dim=0)
hints = torch.unbind(c, dim=0)[:-1]
controlnet_block_samples = [None] * self.main_model_double
for local_idx, base_idx in enumerate(self.injection_layers):
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
controlnet_block_samples[base_idx] = hints[local_idx]
return {"input": controlnet_block_samples}
class QwenImageControlNetModel(QwenImageTransformer2DModel):

View File

@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])

View File

@@ -19,12 +19,6 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def any_suffix_in(keys, prefix, main, suffix_list=[]):
for x in suffix_list:
if "{}{}{}".format(prefix, main, x) in keys:
return True
return False
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -192,7 +186,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -247,8 +241,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -256,8 +249,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -267,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -276,7 +268,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]

View File

@@ -679,19 +679,18 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
def _load_list(self, prio_comfy_cast_weights=False):
loading = []
for n, m in self.model.named_modules():
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
for name, param in m.named_parameters(recurse=True):
if name not in params:
default = True # default random weights in non leaf modules
skip = True # skip random weights in non leaf modules
break
if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@@ -1496,7 +1495,7 @@ class ModelPatcherDynamic(ModelPatcher):
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
for x in loading:
@@ -1561,8 +1560,6 @@ class ModelPatcherDynamic(ModelPatcher):
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
@@ -1582,7 +1579,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
loading = self._load_list(prio_comfy_cast_weights=True)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
@@ -1603,8 +1600,6 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
for m in self.model.modules():
move_weight_functions(m, device_to)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above

View File

@@ -171,9 +171,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
pad_token = self.special_tokens.get("pad", -1)
if end_token is None:
cmp_token = pad_token
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token
@@ -187,21 +186,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
other_embeds = []
eos = False
index = 0
left_pad = False
for y in x:
if isinstance(y, numbers.Integral):
token = int(y)
if index == 0 and token == pad_token:
left_pad = True
if eos or (left_pad and token == pad_token):
if eos:
attention_mask.append(0)
else:
attention_mask.append(1)
left_pad = False
token = int(y)
tokens_temp += [token]
if not eos and token == cmp_token and not left_pad:
if not eos and token == cmp_token:
if end_token is None:
attention_mask[-1] = 0
eos = True

View File

@@ -710,15 +710,6 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith("_norm.scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -907,13 +898,11 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
@@ -1275,15 +1264,6 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1361,14 +1341,6 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)

View File

@@ -10,6 +10,7 @@ import comfy.utils
def sample_manual_loop_no_classes(
model,
ids=None,
paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
@@ -35,6 +36,9 @@ def sample_manual_loop_no_classes(
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
output_audio_codes = []
past_key_values = []
@@ -131,11 +135,13 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
paddings = []
ids = [positive]
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):

View File

@@ -355,6 +355,13 @@ class RMSNorm(nn.Module):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
@@ -383,30 +390,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
out.append((cos, sin))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
nsin = freqs_cis[2]
q_embed = (xq * cos)
q_split = q_embed.shape[-1] // 2
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
k_embed = (xk * cos)
k_split = k_embed.shape[-1] // 2
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed.to(org_dtype), k_embed.to(org_dtype)

View File

@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@@ -97,7 +97,6 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -139,7 +138,6 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens = max(num_tokens, 64)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

View File

@@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}
for k in block_map:
@@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.weight",
"attn.norm_k.weight": "norm.key_norm.weight",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}

View File

@@ -57,7 +57,6 @@ class _RequestConfig:
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
max_retries: int
max_retries_on_rate_limit: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
@@ -66,7 +65,6 @@ class _RequestConfig:
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
@dataclass
@@ -80,7 +78,7 @@ class _PollUIState:
active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@@ -105,8 +103,6 @@ async def sync_op(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> M:
raw = await sync_op_raw(
cls,
@@ -126,8 +122,6 @@ async def sync_op(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@@ -200,8 +194,6 @@ async def sync_op_raw(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
@@ -230,8 +222,6 @@ async def sync_op_raw(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -516,7 +506,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
return "Rate Limit Exceeded: Please try again later."
try:
if isinstance(body, dict):
err = body.get("error")
@@ -596,8 +586,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
rate_limit_delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
@@ -665,14 +653,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] request logging failed: %s", _log_e)
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
@@ -697,33 +688,41 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
should_retry = False
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
wait_time = min(rate_limit_delay, 30.0)
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
should_retry = True
if should_retry:
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
logging.warning(
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
method,
url,
resp.status,
wait_time,
retry_label,
delay,
attempt,
cfg.max_retries,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=_friendly_http_message(resp.status, body),
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
msg = _friendly_http_message(resp.status, body)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -731,27 +730,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
error_message=msg,
)
await sleep_with_interrupt(
wait_time,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
continue
msg = _friendly_http_message(resp.status, body)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
raise Exception(msg)
if expect_binary:
@@ -771,14 +753,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return bytes_payload
else:
try:
@@ -795,39 +780,45 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
if (attempt - rate_limit_attempts) <= cfg.max_retries:
if attempt <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
attempt - rate_limit_attempts,
attempt,
cfg.max_retries,
str(e),
)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
except Exception as _log_e:
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
@@ -840,6 +831,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -847,21 +855,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}",
error_message=f"ApiServerError: {str(e)}",
)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."

View File

@@ -167,25 +167,27 @@ async def download_url_to_bytesio(
with contextlib.suppress(Exception):
dest.seek(0)
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue

View File

@@ -8,6 +8,7 @@ from typing import Any
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
@@ -90,41 +91,38 @@ def log_request_response(
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error("Error writing API log to %s: %s", filepath, str(e))
except Exception as _log_e:
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error("Error writing API log to %s: %s", filepath, str(e))
if __name__ == '__main__':

View File

@@ -255,14 +255,17 @@ async def upload_file(
monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
except Exception as e:
logging.debug("[DEBUG] upload request logging failed: %s", e)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@@ -308,27 +311,31 @@ async def upload_file(
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
except Exception as e:
logging.debug("[DEBUG] upload response logging failed: %s", e)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(
delay,
cls,

View File

@@ -1,132 +0,0 @@
from __future__ import annotations
import hashlib
import os
import numpy as np
import torch
from PIL import Image
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io, UI
from typing_extensions import override
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
hex_color = hex_color.lstrip("#")
if len(hex_color) != 6:
return (0.0, 0.0, 0.0)
r = int(hex_color[0:2], 16) / 255.0
g = int(hex_color[2:4], 16) / 255.0
b = int(hex_color[4:6], 16) / 255.0
return (r, g, b)
class PainterNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Painter",
display_name="Painter",
category="image",
inputs=[
io.Image.Input(
"image",
optional=True,
tooltip="Optional base image to paint over",
),
io.String.Input(
"mask",
default="",
socketless=True,
extra_dict={"widgetType": "PAINTER", "image_upload": True},
),
io.Int.Input(
"width",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Int.Input(
"height",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.String.Input(
"bg_color",
default="#000000",
socketless=True,
extra_dict={"hidden": True, "widgetType": "COLOR"},
),
],
outputs=[
io.Image.Output("IMAGE"),
io.Mask.Output("MASK"),
],
)
@classmethod
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
if image is not None:
h, w = image.shape[1], image.shape[2]
base_image = image
else:
h, w = height, width
r, g, b = hex_to_rgb(bg_color)
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
base_image[0, :, :, 0] = r
base_image[0, :, :, 1] = g
base_image[0, :, :, 2] = b
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
painter_img = node_helpers.pillow(Image.open, mask_path)
painter_img = painter_img.convert("RGBA")
if painter_img.size != (w, h):
painter_img = painter_img.resize((w, h), Image.LANCZOS)
painter_np = np.array(painter_img).astype(np.float32) / 255.0
painter_rgb = painter_np[:, :, :3]
painter_alpha = painter_np[:, :, 3:4]
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
base_np = base_image[0].cpu().numpy()
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
out_image = torch.from_numpy(composited).unsqueeze(0)
else:
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
out_image = base_image
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
@classmethod
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
if os.path.exists(mask_path):
m = hashlib.sha256()
with open(mask_path, "rb") as f:
m.update(f.read())
return m.digest().hex()
return ""
class PainterExtension(ComfyExtension):
@override
async def get_node_list(self):
return [PainterNode]
async def comfy_entrypoint():
return PainterExtension()

View File

@@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input(
"offloading",
default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
tooltip="Depth level for gradient checkpointing.",
),
io.Combo.Input(
"existing_lora",
@@ -1124,15 +1124,6 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode

View File

@@ -2435,7 +2435,6 @@ async def init_builtin_extra_nodes():
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
"nodes_painter.py",
]
import_failed = []

10
requirements-amd.txt Normal file
View File

@@ -0,0 +1,10 @@
# AMD GPU requirements (ROCm)
# Usage: pip install -r requirements-amd.txt
#
# Note: This is for AMD GPUs with ROCm support.
# For experimental Windows/Linux support on RDNA 3/3.5/4, see README.md
--index-url https://download.pytorch.org/whl/rocm7.1
--extra-index-url https://pypi.org/simple
-r requirements.txt

7
requirements-intel.txt Normal file
View File

@@ -0,0 +1,7 @@
# Intel GPU requirements (XPU - Arc GPUs)
# Usage: pip install -r requirements-intel.txt
--index-url https://download.pytorch.org/whl/xpu
--extra-index-url https://pypi.org/simple
-r requirements.txt

6
requirements-nvidia.txt Normal file
View File

@@ -0,0 +1,6 @@
# NVIDIA GPU requirements (CUDA 13.0)
# Usage: pip install -r requirements-nvidia.txt
--extra-index-url https://download.pytorch.org/whl/cu130
-r requirements.txt