mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-15 12:40:03 +00:00
Compare commits
2 Commits
painter-no
...
jk/require
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5907e21dbc | ||
|
|
2058a3a69c |
@@ -227,7 +227,7 @@ Put your VAE in: models/vae
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
|
||||
|
||||
@@ -297,30 +297,6 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class QwenFunControlNet(ControlNet):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||
# unchanged while >1 grows more gently.
|
||||
original_strength = self.strength
|
||||
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||
try:
|
||||
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
finally:
|
||||
self.strength = original_strength
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
@@ -584,7 +560,6 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
sd = model_config.process_unet_state_dict(sd)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
@@ -630,53 +605,6 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
in_features = sd["control_img_in.weight"].shape[1]
|
||||
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||
|
||||
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||
|
||||
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||
control_in_features=in_features,
|
||||
inner_dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
operations=operations,
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
dtype=unet_dtype,
|
||||
)
|
||||
model = controlnet_load_state_dict(model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
control = QwenFunControlNet(
|
||||
model,
|
||||
compression_ratio=1,
|
||||
latent_format=latent_format,
|
||||
# Fun checkpoints already expect their own 33-channel context handling.
|
||||
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||
# and breaks the intended fallback packing path.
|
||||
concat_mask=False,
|
||||
load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype,
|
||||
extra_conds=[],
|
||||
)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -754,8 +682,6 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
@@ -3,6 +3,7 @@ from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ class Approximator(nn.Module):
|
||||
super().__init__()
|
||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
|
||||
@@ -4,6 +4,8 @@ from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
@@ -143,7 +145,7 @@ class NerfGLUBlock(nn.Module):
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
@@ -176,7 +178,7 @@ class NerfGLUBlock(nn.Module):
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -188,7 +190,7 @@ class NerfFinalLayer(nn.Module):
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
|
||||
@@ -5,9 +5,9 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
# Fix import for some custom nodes, TODO: delete eventually.
|
||||
RMSNorm = None
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
@@ -87,12 +87,20 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
@@ -161,7 +169,7 @@ class SiLUActivation(nn.Module):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@@ -189,6 +197,8 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
@@ -214,17 +224,32 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
|
||||
@@ -16,6 +16,7 @@ from .layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@@ -80,7 +81,7 @@ class Flux(nn.Module):
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
|
||||
@@ -241,6 +241,7 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@@ -377,14 +378,14 @@ class HunyuanVideo(nn.Module):
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, :txt.shape[1]] = txt_mask
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
@@ -412,7 +413,7 @@ class HunyuanVideo(nn.Module):
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
@@ -434,9 +435,9 @@ class HunyuanVideo(nn.Module):
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
|
||||
img[:, : img_len] += add
|
||||
|
||||
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
|
||||
img = img[:, : img_len]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
|
||||
@@ -2,196 +2,6 @@ import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
from .model import QwenImageTransformerBlock
|
||||
|
||||
|
||||
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.has_before_proj = has_before_proj
|
||||
if has_before_proj:
|
||||
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_in_features=132,
|
||||
inner_dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.main_model_double = main_model_double
|
||||
self.injection_layers = tuple(injection_layers)
|
||||
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||
# to the reference Gen2/VideoX implementation around strength=1.
|
||||
self.hint_scale = 1.0
|
||||
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||
|
||||
self.control_blocks = torch.nn.ModuleList([])
|
||||
for i in range(num_control_blocks):
|
||||
self.control_blocks.append(
|
||||
QwenImageFunControlBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
has_before_proj=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_hint_tokens(self, hint):
|
||||
if hint is None:
|
||||
return None
|
||||
if hint.ndim == 4:
|
||||
hint = hint.unsqueeze(2)
|
||||
|
||||
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||
if hint.shape[1] == 16 and expected_c == 33:
|
||||
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||
zeros_inpaint = torch.zeros_like(hint)
|
||||
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||
|
||||
bs, c, t, h, w = hint.shape
|
||||
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
orig_shape[0],
|
||||
orig_shape[1],
|
||||
orig_shape[-3],
|
||||
orig_shape[-2] // 2,
|
||||
2,
|
||||
orig_shape[-1] // 2,
|
||||
2,
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(
|
||||
bs,
|
||||
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||
c * 4,
|
||||
)
|
||||
|
||||
expected_in = self.control_img_in.weight.shape[1]
|
||||
cur_in = hidden_states.shape[-1]
|
||||
if cur_in < expected_in:
|
||||
pad = torch.zeros(
|
||||
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||
elif cur_in > expected_in:
|
||||
hidden_states = hidden_states[:, :, :expected_in]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
base_model=None,
|
||||
**kwargs,
|
||||
):
|
||||
if base_model is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||
encoder_hidden_states_mask = None
|
||||
|
||||
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||
hint_tokens = self._process_hint_tokens(hint)
|
||||
if hint_tokens is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||
|
||||
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||
hint_tokens = hint_tokens[:, :max_tokens]
|
||||
hidden_states = hidden_states[:, :max_tokens]
|
||||
img_ids = img_ids[:, :max_tokens]
|
||||
|
||||
txt_start = round(
|
||||
max(
|
||||
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
)
|
||||
)
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
|
||||
hidden_states = base_model.img_in(hidden_states)
|
||||
encoder_hidden_states = base_model.txt_norm(context)
|
||||
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
base_model.time_text_embed(timesteps, hidden_states)
|
||||
if guidance is None
|
||||
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||
)
|
||||
|
||||
c = self.control_img_in(hint_tokens)
|
||||
|
||||
for i, block in enumerate(self.control_blocks):
|
||||
if i == 0:
|
||||
c_in = block.before_proj(c) + hidden_states
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c, dim=0))
|
||||
c_in = all_c.pop(-1)
|
||||
|
||||
encoder_hidden_states, c_out = block(
|
||||
hidden_states=c_in,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||
all_c += [c_skip, c_out]
|
||||
c = torch.stack(all_c, dim=0)
|
||||
|
||||
hints = torch.unbind(c, dim=0)[:-1]
|
||||
|
||||
controlnet_block_samples = [None] * self.main_model_double
|
||||
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||
|
||||
return {"input": controlnet_block_samples}
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
|
||||
@@ -5,7 +5,7 @@ import comfy.utils
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||
|
||||
@@ -19,12 +19,6 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def any_suffix_in(keys, prefix, main, suffix_list=[]):
|
||||
for x in suffix_list:
|
||||
if "{}{}{}".format(prefix, main, x) in keys:
|
||||
return True
|
||||
return False
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
@@ -192,7 +186,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "flux2"
|
||||
@@ -247,8 +241,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
dit_config["out_channels"] = 64
|
||||
@@ -256,8 +249,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["out_dim"] = 3072
|
||||
dit_config["hidden_dim"] = 5120
|
||||
dit_config["n_layers"] = 5
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
|
||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
@@ -267,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
@@ -276,7 +268,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
|
||||
@@ -679,19 +679,18 @@ class ModelPatcher:
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||
def _load_list(self, prio_comfy_cast_weights=False):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
default = False
|
||||
params = { name: param for name, param in m.named_parameters(recurse=False) }
|
||||
params = []
|
||||
skip = False
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
params.append(name)
|
||||
for name, param in m.named_parameters(recurse=True):
|
||||
if name not in params:
|
||||
default = True # default random weights in non leaf modules
|
||||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if default and default_device is not None:
|
||||
for param in params.values():
|
||||
param.data = param.data.to(device=default_device)
|
||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
@@ -1496,7 +1495,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||
#with a high layer rate (e.g. autoregressive LLMs).
|
||||
#prioritize the non-comfy weights (note the order reverse).
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
||||
loading.sort(reverse=True)
|
||||
|
||||
for x in loading:
|
||||
@@ -1561,8 +1560,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||
|
||||
self.model.device = device_to
|
||||
@@ -1582,7 +1579,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
@@ -1603,8 +1600,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if unpatch_weights:
|
||||
self.partially_unload_ram(1e32)
|
||||
self.partially_unload(None, 1e32)
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
|
||||
@@ -171,9 +171,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
pad_token = self.special_tokens.get("pad", -1)
|
||||
if end_token is None:
|
||||
cmp_token = pad_token
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
@@ -187,21 +186,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
other_embeds = []
|
||||
eos = False
|
||||
index = 0
|
||||
left_pad = False
|
||||
for y in x:
|
||||
if isinstance(y, numbers.Integral):
|
||||
token = int(y)
|
||||
if index == 0 and token == pad_token:
|
||||
left_pad = True
|
||||
|
||||
if eos or (left_pad and token == pad_token):
|
||||
if eos:
|
||||
attention_mask.append(0)
|
||||
else:
|
||||
attention_mask.append(1)
|
||||
left_pad = False
|
||||
|
||||
token = int(y)
|
||||
tokens_temp += [token]
|
||||
if not eos and token == cmp_token and not left_pad:
|
||||
if not eos and token == cmp_token:
|
||||
if end_token is None:
|
||||
attention_mask[-1] = 0
|
||||
eos = True
|
||||
|
||||
@@ -710,15 +710,6 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith("_norm.scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@@ -907,13 +898,11 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
||||
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
||||
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
|
||||
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
||||
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
||||
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
@@ -1275,15 +1264,6 @@ class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
@@ -1361,14 +1341,6 @@ class Chroma(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Chroma(self, device=device)
|
||||
|
||||
@@ -10,6 +10,7 @@ import comfy.utils
|
||||
def sample_manual_loop_no_classes(
|
||||
model,
|
||||
ids=None,
|
||||
paddings=[],
|
||||
execution_dtype=None,
|
||||
cfg_scale: float = 2.0,
|
||||
temperature: float = 0.85,
|
||||
@@ -35,6 +36,9 @@ def sample_manual_loop_no_classes(
|
||||
|
||||
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
||||
embeds_batch = embeds.shape[0]
|
||||
for i, t in enumerate(paddings):
|
||||
attention_mask[i, :t] = 0
|
||||
attention_mask[i, t:] = 1
|
||||
|
||||
output_audio_codes = []
|
||||
past_key_values = []
|
||||
@@ -131,11 +135,13 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
|
||||
pos_pad = (len(negative) - len(positive))
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
|
||||
paddings = [pos_pad, neg_pad]
|
||||
ids = [positive, negative]
|
||||
else:
|
||||
paddings = []
|
||||
ids = [positive]
|
||||
|
||||
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
@@ -355,6 +355,13 @@ class RMSNorm(nn.Module):
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
||||
if not isinstance(theta, list):
|
||||
theta = [theta]
|
||||
@@ -383,30 +390,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
||||
else:
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
|
||||
out.append((cos, sin))
|
||||
|
||||
if len(out) == 1:
|
||||
return out[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
org_dtype = xq.dtype
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
nsin = freqs_cis[2]
|
||||
|
||||
q_embed = (xq * cos)
|
||||
q_split = q_embed.shape[-1] // 2
|
||||
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
|
||||
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
|
||||
|
||||
k_embed = (xk * cos)
|
||||
k_split = k_embed.shape[-1] // 2
|
||||
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
|
||||
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
|
||||
|
||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
|
||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
@@ -97,7 +97,6 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||
|
||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
|
||||
out_device = out.device
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
@@ -139,7 +138,6 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
num_tokens = max(num_tokens, 64)
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
|
||||
@@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"ff_context.linear_in.bias": "txt_mlp.0.bias",
|
||||
"ff_context.linear_out.weight": "txt_mlp.2.weight",
|
||||
"ff_context.linear_out.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
@@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "norm.key_norm.weight",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
||||
"attn.to_out.weight": "linear2.weight", # Flux 2
|
||||
}
|
||||
|
||||
@@ -57,7 +57,6 @@ class _RequestConfig:
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None
|
||||
multipart_parser: Callable | None
|
||||
max_retries: int
|
||||
max_retries_on_rate_limit: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
@@ -66,7 +65,6 @@ class _RequestConfig:
|
||||
final_label_on_success: str | None = "Completed"
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -80,7 +78,7 @@ class _PollUIState:
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
@@ -105,8 +103,6 @@ async def sync_op(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
@@ -126,8 +122,6 @@ async def sync_op(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@@ -200,8 +194,6 @@ async def sync_op_raw(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
@@ -230,8 +222,6 @@ async def sync_op_raw(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@@ -516,7 +506,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
@@ -596,8 +586,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
rate_limit_attempts = 0
|
||||
rate_limit_delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: int | None = None
|
||||
extracted_price: float | None = None
|
||||
@@ -665,14 +653,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
@@ -697,33 +688,41 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
should_retry = False
|
||||
wait_time = 0.0
|
||||
retry_label = ""
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
wait_time = min(rate_limit_delay, 30.0)
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||
should_retry = True
|
||||
|
||||
if should_retry:
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
wait_time,
|
||||
retry_label,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@@ -731,27 +730,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
||||
error_message=msg,
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
wait_time,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
@@ -771,14 +753,17 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
@@ -795,39 +780,45 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
if attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt - rate_limit_attempts,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
@@ -840,6 +831,23 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@@ -847,21 +855,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
|
||||
@@ -167,25 +167,27 @@ async def download_url_to_bytesio(
|
||||
with contextlib.suppress(Exception):
|
||||
dest.seek(0)
|
||||
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
# Get the logger instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -90,41 +91,38 @@ def log_request_response(
|
||||
Filenames are sanitized and length-limited for cross-platform safety.
|
||||
If we still fail to write, we fall back to appending into api.log.
|
||||
"""
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -255,14 +255,17 @@ async def upload_file(
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
@@ -308,27 +311,31 @@ async def upload_file(
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io, UI
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0.0, 0.0, 0.0)
|
||||
r = int(hex_color[0:2], 16) / 255.0
|
||||
g = int(hex_color[2:4], 16) / 255.0
|
||||
b = int(hex_color[4:6], 16) / 255.0
|
||||
return (r, g, b)
|
||||
|
||||
|
||||
class PainterNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Painter",
|
||||
display_name="Painter",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional base image to paint over",
|
||||
),
|
||||
io.String.Input(
|
||||
"mask",
|
||||
default="",
|
||||
socketless=True,
|
||||
extra_dict={"widgetType": "PAINTER", "image_upload": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"width",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"height",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.String.Input(
|
||||
"bg_color",
|
||||
default="#000000",
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True, "widgetType": "COLOR"},
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("IMAGE"),
|
||||
io.Mask.Output("MASK"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
|
||||
if image is not None:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
base_image = image
|
||||
else:
|
||||
h, w = height, width
|
||||
r, g, b = hex_to_rgb(bg_color)
|
||||
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
|
||||
base_image[0, :, :, 0] = r
|
||||
base_image[0, :, :, 1] = g
|
||||
base_image[0, :, :, 2] = b
|
||||
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
painter_img = node_helpers.pillow(Image.open, mask_path)
|
||||
painter_img = painter_img.convert("RGBA")
|
||||
|
||||
if painter_img.size != (w, h):
|
||||
painter_img = painter_img.resize((w, h), Image.LANCZOS)
|
||||
|
||||
painter_np = np.array(painter_img).astype(np.float32) / 255.0
|
||||
painter_rgb = painter_np[:, :, :3]
|
||||
painter_alpha = painter_np[:, :, 3:4]
|
||||
|
||||
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
|
||||
|
||||
base_np = base_image[0].cpu().numpy()
|
||||
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
|
||||
out_image = torch.from_numpy(composited).unsqueeze(0)
|
||||
else:
|
||||
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
|
||||
out_image = base_image
|
||||
|
||||
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
if os.path.exists(mask_path):
|
||||
m = hashlib.sha256()
|
||||
with open(mask_path, "rb") as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
class PainterExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [PainterNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return PainterExtension()
|
||||
@@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
||||
tooltip="Depth level for gradient checkpointing.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
@@ -1124,15 +1124,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
if mp.is_dynamic():
|
||||
if not bypass_mode:
|
||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
||||
bypass_mode = True
|
||||
offloading = True
|
||||
elif offloading:
|
||||
if not bypass_mode:
|
||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, dtype, bucket_mode
|
||||
|
||||
1
nodes.py
1
nodes.py
@@ -2435,7 +2435,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_lora_debug.py",
|
||||
"nodes_color.py",
|
||||
"nodes_toolkit.py",
|
||||
"nodes_painter.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
10
requirements-amd.txt
Normal file
10
requirements-amd.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
# AMD GPU requirements (ROCm)
|
||||
# Usage: pip install -r requirements-amd.txt
|
||||
#
|
||||
# Note: This is for AMD GPUs with ROCm support.
|
||||
# For experimental Windows/Linux support on RDNA 3/3.5/4, see README.md
|
||||
|
||||
--index-url https://download.pytorch.org/whl/rocm7.1
|
||||
--extra-index-url https://pypi.org/simple
|
||||
|
||||
-r requirements.txt
|
||||
7
requirements-intel.txt
Normal file
7
requirements-intel.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Intel GPU requirements (XPU - Arc GPUs)
|
||||
# Usage: pip install -r requirements-intel.txt
|
||||
|
||||
--index-url https://download.pytorch.org/whl/xpu
|
||||
--extra-index-url https://pypi.org/simple
|
||||
|
||||
-r requirements.txt
|
||||
6
requirements-nvidia.txt
Normal file
6
requirements-nvidia.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
# NVIDIA GPU requirements (CUDA 13.0)
|
||||
# Usage: pip install -r requirements-nvidia.txt
|
||||
|
||||
--extra-index-url https://download.pytorch.org/whl/cu130
|
||||
|
||||
-r requirements.txt
|
||||
Reference in New Issue
Block a user