Compare commits

..

16 Commits

Author SHA1 Message Date
comfyanonymous
8affde131f Fix flux controlnet. 2026-02-12 22:47:27 -05:00
comfyanonymous
edde057369 Fix bfl control loras. 2026-02-12 22:40:49 -05:00
comfyanonymous
ca0c349005 Fix 2026-02-12 22:23:21 -05:00
comfyanonymous
3f9800b33a Add process_unet_state_dict method to handle state dict 2026-02-12 22:02:33 -05:00
comfyanonymous
6828021606 Remove unused import in layers.py
Removed unused import of comfy.ops.
2026-02-12 18:01:44 -05:00
comfyanonymous
75e22eb72e Remove unused import in layers.py
Removed unused import of common_dit from layers.py
2026-02-12 17:59:18 -05:00
comfyanonymous
c61f69acf5 Use torch RMSNorm for flux models and refactor hunyuan video code. 2026-02-12 17:57:19 -05:00
Alexander Piskun
4a93a62371 fix(api-nodes): add separate retry budget for 429 rate limit responses (#12421) 2026-02-12 01:38:51 -08:00
comfyanonymous
66c18522fb Add a tip for common error. (#12414) 2026-02-11 22:12:16 -05:00
askmyteapot
e5ae670a40 Update ace15.py to allow min_p sampling (#12373) 2026-02-11 20:28:48 -05:00
rattus
3fe61cedda model_patcher: guard against none model_dtype (#12410)
Handle the case where _model_dtype exists but is None by applying the intended fallback.
2026-02-11 14:54:02 -05:00
rattus
2a4328d639 ace15: Use dynamic_vram friendly trange (#12409)
Factor out the ksampler trange and use it in the ACE LLM to prevent the
silent stall at 0 and the distorted step-rate estimate caused by the first-step model load.
2026-02-11 14:53:42 -05:00
rattus
d297a749a2 dynamic_vram: Fix windows Aimdo crash + Fix LLM performance (#12408)
* model_management: lazy-cache aimdo_tensor

These tensors constructed from aimdo allocations are CPU-expensive to
create on the PyTorch side. Add a cached version, valid while the
signature matches, to fast-path past whatever torch is doing.

* dynamic_vram: Minimize fast path CPU work

Move as much as possible inside the not-resident branch and cache
the formed weight and bias rather than the flat intermediates. At
extreme layer-weight rates this adds up.
2026-02-11 14:50:16 -05:00
Alexander Piskun
2b7cc7e3b6 [API Nodes] enable Magnific Upscalers (#12179)
* feat(api-nodes): enable Magnific Upscalers

* update price badges

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-11 11:30:19 -08:00
Benjamin Lu
4993411fd9 Dispatch desktop auto-bump when a ComfyUI release is published (#12398)
* Dispatch desktop auto-bump on ComfyUI release publish

* Fix release webhook secret checks in step conditions

* Require desktop dispatch token in release webhook

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Luke Mino-Altherr <lminoaltherr@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-02-11 11:15:13 -08:00
Alexander Piskun
2c7cef4a23 fix(api-nodes): retry on connection errors during polling instead of aborting (#12393) 2026-02-11 10:51:49 -08:00
24 changed files with 415 additions and 314 deletions

View File

@@ -7,6 +7,8 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
env:
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error
echo "✅ Release webhook sent successfully"
- name: Send repository dispatch to desktop
env:
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.event.release.tag_name }}
RELEASE_URL: ${{ github.event.release.html_url }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_release_published",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"

View File

@@ -560,6 +560,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)

View File

@@ -1,12 +1,11 @@
import math
import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange as trange_, tqdm
from tqdm.auto import tqdm
from . import utils
from . import deis
@@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)
pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
from comfy.utils import model_trange as trange
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])

View File

@@ -3,7 +3,6 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
ModulationOut,
)
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property

View File

@@ -4,8 +4,6 @@ from functools import lru_cache
import torch
from torch import nn
from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module):
"""
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,

View File

@@ -5,8 +5,6 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
class EmbedND(nn.Module):
@@ -87,20 +85,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -169,7 +159,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +187,6 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -224,32 +212,17 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img blocks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
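The hand-rolled RMSNorm class removed above and the operations.RMSNorm that replaces it compute the same normalization. A minimal sketch of that equivalence, assuming PyTorch 2.4+ (which ships torch.nn.RMSNorm):

import torch

def rms_norm_manual(x, scale, eps=1e-6):
    # root-mean-square normalization over the last dim, then the learned scale
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * scale

dim = 64
norm = torch.nn.RMSNorm(dim, eps=1e-6)  # weight is initialized to ones
x = torch.randn(2, 8, dim)
assert torch.allclose(norm(x), rms_norm_manual(x, norm.weight), atol=1e-5)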

View File

@@ -16,7 +16,6 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
)
@dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else:
self.txt_norm = None

View File

@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
img = torch.cat((img, txt), 1)
img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
img = img[:, : img_len]
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
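A small sketch of the indexing implied by the new txt-first ordering above (hypothetical token counts; the model code is authoritative):

import torch

txt = torch.randn(1, 77, 16)   # hypothetical text tokens
img = torch.randn(1, 256, 16)  # hypothetical image tokens
img_len = img.shape[1]

x = torch.cat((txt, img), dim=1)                       # txt now leads the joint sequence
img_out = x[:, txt.shape[1]: img_len + txt.shape[1]]   # slice the image tokens back out
assert torch.equal(img_out, img)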

View File

@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])

View File

@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def any_suffix_in(keys, prefix, main, suffix_list=[]):
for x in suffix_list:
if "{}{}{}".format(prefix, main, x) in keys:
return True
return False
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]

View File

@@ -1213,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
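The pattern here is a lazy, signature-keyed cache: the expensive tensor view is rebuilt only when the vbar signature changes. A minimal sketch of that pattern (illustrative names, not the comfy_aimdo API):

def get_cached_view(holder, signature, build_view):
    # Reuse the stored view while the allocation signature is unchanged;
    # rebuilding it on every call is the CPU cost this commit avoids.
    if getattr(holder, "_v_signature", None) == signature and getattr(holder, "_v_tensor", None) is not None:
        return holder._v_tensor
    view = build_view()          # e.g. wrap the raw allocation and reinterpret its geometry
    holder._v_tensor = view
    holder._v_signature = signature
    return view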

View File

@@ -1525,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
@@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
allocated_size += v_weight_size
else:
@@ -1552,12 +1551,11 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
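The subtlety being fixed here: getattr's default only applies when the attribute is missing, not when it is present but None, so the added "or weight.dtype" supplies the intended fallback. A small illustration with a hypothetical module:

import torch

class DummyLayer:
    weight_comfy_model_dtype = None  # attribute exists, but holds None

m = DummyLayer()
weight_dtype = torch.float16

# The default is ignored because the attribute exists:
assert getattr(m, "weight_comfy_model_dtype", weight_dtype) is None

# The fix falls back when the stored value is None as well:
model_dtype = getattr(m, "weight_comfy_model_dtype", None) or weight_dtype
assert model_dtype == weight_dtype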

View File

@@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = s._v_tensor
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)

View File

@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith("_norm.scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)

View File

@@ -3,7 +3,6 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
from tqdm.auto import trange
import yaml
import comfy.utils
@@ -17,6 +16,7 @@ def sample_manual_loop_no_classes(
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
min_p: float = 0.000,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
@@ -52,7 +52,7 @@ def sample_manual_loop_no_classes(
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in trange(max_new_tokens, desc="LM sampling"):
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
@@ -81,6 +81,12 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
if min_p is not None and min_p > 0:
probs = torch.softmax(cfg_logits, dim=-1)
p_max = probs.max(dim=-1, keepdim=True).values
indices_to_remove = probs < (min_p * p_max)
cfg_logits[indices_to_remove] = remove_logit_value
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@@ -111,7 +117,7 @@ def sample_manual_loop_no_classes(
return output_audio_codes
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
positive = [[token for token, _ in inner_list] for inner_list in positive]
positive = positive[0]
@@ -135,7 +141,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
paddings = []
ids = [positive]
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -193,6 +199,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
min_p = kwargs.get("min_p", 0.000)
duration = math.ceil(duration)
kwargs["duration"] = duration
@@ -240,6 +247,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
}
return out
@@ -300,7 +308,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
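Min-p sampling keeps only tokens whose probability is at least min_p times that of the most likely token. A standalone sketch of the filter added above:

import torch

def apply_min_p(logits, min_p, remove_logit_value=-float("inf")):
    # Mask tokens whose probability falls below min_p * p(top token).
    probs = torch.softmax(logits, dim=-1)
    p_max = probs.max(dim=-1, keepdim=True).values
    out = logits.clone()
    out[probs < (min_p * p_max)] = remove_logit_value
    return out

logits = torch.tensor([[4.0, 3.0, 0.0, -2.0]])
print(apply_min_p(logits, min_p=0.1))  # unlikely tokens are pushed to -inf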

View File

@@ -27,6 +27,7 @@ from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
import json
@@ -674,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
}
for k in block_map:
@@ -700,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.norm_q.weight": "norm.query_norm.weight",
"attn.norm_k.weight": "norm.key_norm.weight",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}
@@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED

View File

@@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
)
_EUR_TO_USD = 1.19
def _tier_price_eur(megapixels: float) -> float:
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
if megapixels <= 1.3:
return 0.143
if megapixels <= 3.0:
return 0.286
if megapixels <= 6.4:
return 0.429
return 1.716
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
num_steps = int(math.log2(scale))
total_eur = 0.0
pixels = width * height
for _ in range(num_steps):
total_eur += _tier_price_eur(pixels / 1_000_000)
pixels *= 4
return round(total_eur * _EUR_TO_USD, 2)
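As a worked example of the helper above (input size chosen purely for illustration): a 2048x2048 image upscaled 4x runs log2(4) = 2 steps; the first step (about 4.19 MP) falls in the 0.429 EUR tier and the second (about 16.8 MP) in the 1.716 EUR tier, for a total of 2.145 EUR, roughly $2.55 at the 1.19 conversion rate.

# Hypothetical input size, only to exercise the helper defined above.
print(_calculate_magnific_upscale_price_usd(2048, 2048, 4))  # -> 2.55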
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod
@@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
expr="""
(
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
$ad := widgets.auto_downscale;
$mins := $ad
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
)
""",
),
@@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
final_height, final_width = get_image_dimensions(image)
actual_scale = int(scale_factor.rstrip("x"))
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr="""
(
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
)
""",
),
@@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
final_height, final_width = get_image_dimensions(image)
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
# MagnificImageUpscalerCreativeNode,
# MagnificImageUpscalerPreciseV2Node,
MagnificImageUpscalerCreativeNode,
MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode,
MagnificImageRelightNode,
MagnificImageSkinEnhancerNode,

View File

@@ -57,6 +57,7 @@ class _RequestConfig:
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
max_retries: int
max_retries_on_rate_limit: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
@@ -65,6 +66,7 @@ class _RequestConfig:
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
@dataclass
@@ -78,7 +80,7 @@ class _PollUIState:
active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@@ -103,6 +105,8 @@ async def sync_op(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> M:
raw = await sync_op_raw(
cls,
@@ -122,6 +126,8 @@ async def sync_op(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@@ -143,9 +149,9 @@ async def poll_op(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -194,6 +200,8 @@ async def sync_op_raw(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
@@ -222,6 +230,8 @@ async def sync_op_raw(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -240,9 +250,9 @@ async def poll_op_raw(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
return "Rate Limit Exceeded: Please try again later."
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
try:
if isinstance(body, dict):
err = body.get("error")
@@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
rate_limit_attempts = 0
rate_limit_delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
@@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] request logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
@@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
should_retry = False
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
wait_time = min(rate_limit_delay, 30.0)
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
should_retry = True
if should_retry:
logging.warning(
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
method,
url,
resp.status,
delay,
attempt,
cfg.max_retries,
wait_time,
retry_label,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=_friendly_http_message(resp.status, body),
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
msg = _friendly_http_message(resp.status, body)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
wait_time,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
continue
msg = _friendly_http_message(resp.status, body)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
raise Exception(msg)
if expect_binary:
@@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
return bytes_payload
else:
try:
@@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
if attempt <= cfg.max_retries:
if (attempt - rate_limit_attempts) <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
attempt,
attempt - rate_limit_attempts,
cfg.max_retries,
str(e),
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
except Exception as _log_e:
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(
delay,
cfg.node_cls,
@@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
)
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."

View File

@@ -167,27 +167,25 @@ async def download_url_to_bytesio(
with contextlib.suppress(Exception):
dest.seek(0)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue

View File

@@ -8,7 +8,6 @@ from typing import Any
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
@@ -91,38 +90,41 @@ def log_request_response(
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error("Error writing API log to %s: %s", filepath, str(e))
log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url)
log_content: list[str] = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data is not None:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content is not None:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug("API log saved to: %s", filepath)
except Exception as e:
logger.error("Error writing API log to %s: %s", filepath, str(e))
except Exception as _log_e:
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
if __name__ == '__main__':

View File

@@ -255,17 +255,14 @@ async def upload_file(
monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None
try:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
except Exception as e:
logging.debug("[DEBUG] upload request logging failed: %s", e)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_params=None,
request_data=f"[File data {len(data)} bytes]",
)
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@@ -311,31 +308,27 @@ async def upload_file(
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
except Exception as e:
logging.debug("[DEBUG] upload response logging failed: %s", e)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers or None,
request_data=f"[File data {len(data)} bytes]",
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(
delay,
cls,

View File

@@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)

View File

@@ -623,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node, make sure the correct file(s) and type are selected."
error_details = {
"node_id": real_node_id,

View File

@@ -30,6 +30,3 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
# test
fastapi