mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 12:11:43 +00:00
Compare commits
125 Commits
master
...
worksplit-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48deb15c0e | ||
|
|
4b93c4360f | ||
|
|
da3864436c | ||
|
|
b418fb1582 | ||
|
|
20803749c3 | ||
|
|
3fab720be9 | ||
|
|
afdddcee66 | ||
|
|
1d8e379f41 | ||
|
|
5f4fcd19e7 | ||
|
|
d52dcbc88f | ||
|
|
84f465e791 | ||
|
|
be35378986 | ||
|
|
f410d28b33 | ||
|
|
f4b99bc623 | ||
|
|
df2fd4c869 | ||
|
|
4661d1db5a | ||
|
|
b326a544d5 | ||
|
|
d89dd5f0b0 | ||
|
|
8cbbf0be6c | ||
|
|
c2115a4bac | ||
|
|
bb44c2ecb9 | ||
|
|
efcd8280d6 | ||
|
|
9e9c129cd0 | ||
|
|
ac14ee68c0 | ||
|
|
2c8f485434 | ||
|
|
383f9b34cb | ||
|
|
b0741c7e5b | ||
|
|
1489399cb5 | ||
|
|
3677943fa5 | ||
|
|
cfb63bfcd7 | ||
|
|
962c3c832c | ||
|
|
6ea69369ce | ||
|
|
b4f559b34d | ||
|
|
df122a7dba | ||
|
|
67e906aa64 | ||
|
|
382f84a826 | ||
|
|
9cca36fa2b | ||
|
|
5d5024296d | ||
|
|
3b90a30178 | ||
|
|
3c4104652b | ||
|
|
9855baaab3 | ||
|
|
d53479a197 | ||
|
|
443a795850 | ||
|
|
431dec8e53 | ||
|
|
44e053c26d | ||
|
|
1ae98932f1 | ||
|
|
0336b0ace8 | ||
|
|
8ae25235ec | ||
|
|
9726eac475 | ||
|
|
272e8d42c1 | ||
|
|
6211d2be5a | ||
|
|
8be711715c | ||
|
|
b5cccf1325 | ||
|
|
2a54a904f4 | ||
|
|
ed6f92c975 | ||
|
|
adc66c0698 | ||
|
|
ccd5c01e5a | ||
|
|
2fa9affcc1 | ||
|
|
407a5a656f | ||
|
|
9ce9ff8ef8 | ||
|
|
63567c0ce8 | ||
|
|
a786ce5ead | ||
|
|
4879b47648 | ||
|
|
5ccec33c22 | ||
|
|
219d3cd0d0 | ||
|
|
c4ba399475 | ||
|
|
cc928a786d | ||
|
|
6e144b98c4 | ||
|
|
6dca17bd2d | ||
|
|
5080105c23 | ||
|
|
093914a247 | ||
|
|
605893d3cf | ||
|
|
048f4f0b3a | ||
|
|
d2504fb701 | ||
|
|
b03763bca6 | ||
|
|
476aa79b64 | ||
|
|
441cfd1a7a | ||
|
|
99a5c1068a | ||
|
|
02747cde7d | ||
|
|
0b3233b4e2 | ||
|
|
eda866bf51 | ||
|
|
e3298b84de | ||
|
|
c7feef9060 | ||
|
|
51af7fa1b4 | ||
|
|
46969c380a | ||
|
|
5db4277449 | ||
|
|
02a4d0ad7d | ||
|
|
ef137ac0b6 | ||
|
|
328d4f16a9 | ||
|
|
bdbcb85b8d | ||
|
|
6c9e94bae7 | ||
|
|
bfce723311 | ||
|
|
31f5458938 | ||
|
|
2145a202eb | ||
|
|
25818dc848 | ||
|
|
198953cd08 | ||
|
|
ec16ee2f39 | ||
|
|
d5088072fb | ||
|
|
8d4b50158e | ||
|
|
e88c6c03ff | ||
|
|
d3cf2b7b24 | ||
|
|
7448f02b7c | ||
|
|
871258aa72 | ||
|
|
66838ebd39 | ||
|
|
7333281698 | ||
|
|
3cd4c5cb0a | ||
|
|
11c6d56037 | ||
|
|
216fea15ee | ||
|
|
58bf8815c8 | ||
|
|
1b38f5bf57 | ||
|
|
2724ac4a60 | ||
|
|
f48f90e471 | ||
|
|
6463c39ce0 | ||
|
|
0a7e2ae787 | ||
|
|
03a97b604a | ||
|
|
4446c86052 | ||
|
|
8270ff312f | ||
|
|
db2d7ad9ba | ||
|
|
6620d86318 | ||
|
|
111fd0cadf | ||
|
|
776aa734e1 | ||
|
|
5a2ad032cb | ||
|
|
d44295ef71 | ||
|
|
bf21be066f | ||
|
|
72bbf49349 |
@@ -1,2 +0,0 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
pause
|
||||
@@ -139,9 +139,9 @@ Example:
|
||||
"_quantization_metadata": {
|
||||
"format_version": "1.0",
|
||||
"layers": {
|
||||
"model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
|
||||
"model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
|
||||
"model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
|
||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -165,4 +165,4 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
|
||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||
|
||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||
@@ -182,7 +182,7 @@
|
||||
]
|
||||
},
|
||||
"widgets_values": [
|
||||
0
|
||||
50
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -316,7 +316,7 @@
|
||||
"step": 1
|
||||
},
|
||||
"widgets_values": [
|
||||
0
|
||||
30
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
|
||||
@@ -15,13 +15,14 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@@ -64,6 +65,18 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
@@ -77,7 +90,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet = None
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@@ -85,6 +98,7 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@@ -111,17 +125,38 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@@ -130,7 +165,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c):
|
||||
def copy_to(self, c: ControlBase):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@@ -284,6 +319,14 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@@ -906,6 +949,14 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
if not comfy.model_management.supports_fp64(pos.device):
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = pos.device
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos.to(device), omega)
|
||||
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
rot_dim = freqs_cis.shape[-1]
|
||||
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
|
||||
cos_ = freqs_cis[0]
|
||||
sin_ = freqs_cis[1]
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x_rotated = torch.cat((-x2, x1), dim=-1)
|
||||
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
|
||||
|
||||
class ErnieImageEmbedND3(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: tuple):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = list(axes_dim)
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
|
||||
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
|
||||
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
|
||||
|
||||
class ErnieImagePatchEmbedDynamic(nn.Module):
|
||||
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
batch_size, dim, height, width = x.shape
|
||||
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool = False):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
if self.flip_sin_to_cos:
|
||||
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
|
||||
else:
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
return emb
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None):
|
||||
super().__init__()
|
||||
Linear = operations.Linear
|
||||
self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
class ErnieImageAttention(nn.Module):
|
||||
def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = heads * dim_head
|
||||
|
||||
Linear = operations.Linear
|
||||
RMSNorm = operations.RMSNorm
|
||||
|
||||
self.to_q = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.to_k = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.to_v = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = nn.ModuleList([Linear(self.inner_dim, query_dim, bias=False, device=device, dtype=dtype)])
|
||||
|
||||
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor:
|
||||
B, S, _ = x.shape
|
||||
|
||||
q_flat = self.to_q(x)
|
||||
k_flat = self.to_k(x)
|
||||
v_flat = self.to_v(x)
|
||||
|
||||
query = q_flat.view(B, S, self.heads, self.head_dim)
|
||||
key = k_flat.view(B, S, self.heads, self.head_dim)
|
||||
|
||||
query = self.norm_q(query)
|
||||
key = self.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
query, key = query.to(x.dtype), key.to(x.dtype)
|
||||
|
||||
q_flat = query.reshape(B, S, -1)
|
||||
k_flat = key.reshape(B, S, -1)
|
||||
|
||||
hidden_states = optimized_attention(q_flat, k_flat, v_flat, self.heads, mask=attention_mask)
|
||||
|
||||
return self.to_out[0](hidden_states)
|
||||
|
||||
class ErnieImageFeedForward(nn.Module):
|
||||
def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None):
|
||||
super().__init__()
|
||||
Linear = operations.Linear
|
||||
self.gate_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.linear_fc2 = Linear(ffn_hidden_size, hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
|
||||
|
||||
class ErnieImageSharedAdaLNBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||
super().__init__()
|
||||
RMSNorm = operations.RMSNorm
|
||||
|
||||
self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
|
||||
self.self_attention = ErnieImageAttention(
|
||||
query_dim=hidden_size,
|
||||
dim_head=hidden_size // num_heads,
|
||||
heads=num_heads,
|
||||
eps=eps,
|
||||
operations=operations,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
|
||||
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
|
||||
|
||||
residual = x
|
||||
x_norm = self.adaLN_sa_ln(x)
|
||||
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
|
||||
|
||||
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
|
||||
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
|
||||
|
||||
residual = x
|
||||
x_norm = self.adaLN_mlp_ln(x)
|
||||
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
|
||||
|
||||
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
|
||||
|
||||
class ErnieImageAdaLNContinuous(nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||
super().__init__()
|
||||
LayerNorm = operations.LayerNorm
|
||||
Linear = operations.Linear
|
||||
self.norm = LayerNorm(hidden_size, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
|
||||
self.linear = Linear(hidden_size, hidden_size * 2, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
|
||||
x = self.norm(x)
|
||||
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return x
|
||||
|
||||
class ErnieImageModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 4096,
|
||||
num_attention_heads: int = 32,
|
||||
num_layers: int = 36,
|
||||
ffn_hidden_size: int = 12288,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 128,
|
||||
patch_size: int = 1,
|
||||
text_in_dim: int = 3072,
|
||||
rope_theta: int = 256,
|
||||
rope_axes_dim: tuple = (32, 48, 48),
|
||||
eps: float = 1e-6,
|
||||
qk_layernorm: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels
|
||||
|
||||
Linear = operations.Linear
|
||||
|
||||
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, device, dtype)
|
||||
self.text_proj = Linear(text_in_dim, hidden_size, bias=False, device=device, dtype=dtype) if text_in_dim != hidden_size else None
|
||||
|
||||
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False)
|
||||
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, device, dtype)
|
||||
|
||||
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
Linear(hidden_size, 6 * hidden_size, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, device, dtype)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, device, dtype)
|
||||
self.final_linear = Linear(hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, timesteps, context, **kwargs):
|
||||
device, dtype = x.device, x.dtype
|
||||
B, C, H, W = x.shape
|
||||
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
|
||||
N_img = Hp * Wp
|
||||
|
||||
img_bsh = self.x_embedder(x)
|
||||
|
||||
text_bth = context
|
||||
if self.text_proj is not None and text_bth.numel() > 0:
|
||||
text_bth = self.text_proj(text_bth)
|
||||
Tmax = text_bth.shape[1]
|
||||
|
||||
hidden_states = torch.cat([img_bsh, text_bth], dim=1)
|
||||
|
||||
text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32)
|
||||
text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32)
|
||||
index = float(Tmax)
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
|
||||
h_len, w_len = float(Hp), float(Wp)
|
||||
h_offset, w_offset = 0.0, 0.0
|
||||
|
||||
if rope_options is not None:
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
index += rope_options.get("shift_t", 0.0)
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32)
|
||||
image_ids[:, :, 0] = image_ids[:, :, 1] + index
|
||||
image_ids[:, :, 1] = image_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32).unsqueeze(1)
|
||||
image_ids[:, :, 2] = image_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32).unsqueeze(0)
|
||||
|
||||
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
|
||||
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
|
||||
del image_ids, text_ids
|
||||
|
||||
sample = self.time_proj(timesteps).to(dtype)
|
||||
c = self.time_embedding(sample)
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
||||
t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
]
|
||||
|
||||
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, rotary_pos_emb, temb)
|
||||
|
||||
hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states)
|
||||
|
||||
patches = self.final_linear(hidden_states)[:, :N_img, :]
|
||||
output = (
|
||||
patches.view(B, Hp, Wp, p, p, self.out_channels)
|
||||
.permute(0, 5, 1, 3, 2, 4)
|
||||
.contiguous()
|
||||
.view(B, self.out_channels, H, W)
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if not comfy.model_management.supports_fp64(pos.device):
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = pos.device
|
||||
|
||||
@@ -90,7 +90,7 @@ class HeatmapHead(torch.nn.Module):
|
||||
origin_max = np.max(hm[k])
|
||||
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
|
||||
dr[border:-border, border:-border] = hm[k].copy()
|
||||
dr = gaussian_filter(dr, sigma=2.0, truncate=2.5)
|
||||
dr = gaussian_filter(dr, sigma=2.0)
|
||||
hm[k] = dr[border:-border, border:-border].copy()
|
||||
cur_max = np.max(hm[k])
|
||||
if cur_max > 0:
|
||||
|
||||
@@ -53,7 +53,6 @@ import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
import comfy.ldm.ernie.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -1963,14 +1962,3 @@ class Kandinsky5Image(Kandinsky5):
|
||||
class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
class ErnieImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
@@ -713,11 +713,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ernie"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@@ -32,6 +33,11 @@ import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@@ -206,6 +212,25 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
if is_nvidia():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device(i))
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
devices.remove(get_torch_device())
|
||||
return devices
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@@ -494,9 +519,13 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
@@ -529,7 +558,7 @@ def module_mmap_residency(module, free=False):
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@@ -537,7 +566,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model):
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@@ -548,6 +577,7 @@ class LoadedModel:
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
@@ -1732,21 +1762,6 @@ def supports_mxfp8_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_fp64(device=None):
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return False
|
||||
|
||||
if is_directml_enabled():
|
||||
return False
|
||||
|
||||
if is_ixuca():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
@@ -1794,7 +1809,34 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device)
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
|
||||
@@ -23,6 +23,7 @@ import inspect
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
import copy
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
@@ -75,12 +76,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
@@ -272,7 +276,10 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@@ -326,6 +333,8 @@ class ModelPatcher:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
model_override = temp_model_patcher.get_clone_model_override()
|
||||
if model_override is None:
|
||||
model_override = self.get_clone_model_override()
|
||||
@@ -384,19 +393,98 @@ class ModelPatcher:
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.cached_patcher_init = self.cached_patcher_init
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
n = self.clone()
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
if self.cached_patcher_init is not None:
|
||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
n.model = temp_model_patcher.model
|
||||
else:
|
||||
n.model = copy.deepcopy(n.model)
|
||||
# unlike for normal clone, backup dicts that shared same ref should not;
|
||||
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
|
||||
n.backup = copy.deepcopy(n.backup)
|
||||
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
|
||||
n.hook_backup = copy.deepcopy(n.hook_backup)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@@ -1167,7 +1255,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models = []
|
||||
all_models: list[ModelPatcher] = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@@ -1221,9 +1309,13 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep):
|
||||
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep)
|
||||
callback(self, timestep, model_options, ignore_multigpu)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options, ignore_multigpu=True)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@@ -1236,12 +1328,18 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@@ -1253,6 +1351,28 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
|
||||
230
comfy/multigpu.py
Normal file
230
comfy/multigpu.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
def __init__(self, devices: list[torch.device]):
|
||||
self._workers: list[threading.Thread] = []
|
||||
self._work_queues: dict[torch.device, queue.Queue] = {}
|
||||
self._result_queues: dict[torch.device, queue.Queue] = {}
|
||||
|
||||
for device in devices:
|
||||
wq = queue.Queue()
|
||||
rq = queue.Queue()
|
||||
self._work_queues[device] = wq
|
||||
self._result_queues[device] = rq
|
||||
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
|
||||
t.start()
|
||||
self._workers.append(t)
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
return
|
||||
result_q.put((None, e))
|
||||
return
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
break
|
||||
fn, args, kwargs = item
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
result_q.put((result, None))
|
||||
except Exception as e:
|
||||
result_q.put((None, e))
|
||||
|
||||
def submit(self, device: torch.device, fn, *args, **kwargs):
|
||||
self._work_queues[device].put((fn, args, kwargs))
|
||||
|
||||
def get_result(self, device: torch.device):
|
||||
return self._result_queues[device].get()
|
||||
|
||||
@property
|
||||
def devices(self) -> list[torch.device]:
|
||||
return list(self._work_queues.keys())
|
||||
|
||||
def shutdown(self):
|
||||
for wq in self._work_queues.values():
|
||||
wq.put(None) # sentinel
|
||||
for t in self._workers:
|
||||
t.join(timeout=5.0)
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||
extra_devices = limit_extra_devices.copy()
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
|
||||
# multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
# new_multigpu_models = []
|
||||
# for m in multigpu_models:
|
||||
# if m.load_device in limit_extra_devices:
|
||||
# new_multigpu_models.append(m)
|
||||
# model.set_additional_models("multigpu", new_multigpu_models)
|
||||
# persist skip_devices for use in sampling code
|
||||
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
|
||||
# model.model_options["multigpu_skip_devices"] = skip_devices
|
||||
return model
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
||||
@@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||
if p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
|
||||
@@ -20,7 +20,6 @@ try:
|
||||
if cuda_version < (13,):
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
@@ -118,6 +120,47 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
@@ -142,7 +185,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
real_model: BaseModel = None
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
@@ -154,7 +198,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
real_model: BaseModel = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@@ -200,3 +244,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.model_management
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
@@ -16,6 +18,7 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.multigpu
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
@@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list):
|
||||
def cond_cat(c_list, device=None):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@@ -153,6 +156,8 @@ def cond_cat(c_list):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None and hasattr(out[k], 'to'):
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@@ -212,7 +217,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@@ -244,7 +251,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@@ -345,6 +352,212 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
|
||||
total_conds = 0
|
||||
for to_run in hooked_to_run.values():
|
||||
total_conds += len(to_run)
|
||||
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
|
||||
index_device = 0
|
||||
current_device = devices[index_device]
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
current_device = devices[index_device % len(devices)]
|
||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
||||
# keep track of conds currently scheduled onto this device
|
||||
batched_to_run_length = 0
|
||||
for btr in batched_to_run:
|
||||
batched_to_run_length += len(btr[1])
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
# make sure not over conds_per_device limit when creating temp batch
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = comfy.model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
conds_to_batch = []
|
||||
for x in to_batch:
|
||||
conds_to_batch.append(to_run.pop(x))
|
||||
batched_to_run_length += len(conds_to_batch)
|
||||
|
||||
batched_to_run.append((hooks, conds_to_batch))
|
||||
if batched_to_run_length >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
class thread_result(NamedTuple):
|
||||
output: Any
|
||||
mult: Any
|
||||
area: Any
|
||||
batch_chunks: int
|
||||
cond_or_uncond: Any
|
||||
error: Exception = None
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control: ControlBase = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep.to(device)
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
transformer_options["multigpu_thread_device"] = device
|
||||
|
||||
cast_transformer_options(transformer_options, device=device)
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
device_control = control.get_instance_for_device(device)
|
||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
except Exception as e:
|
||||
results.append(thread_result(None, None, None, None, None, error=e))
|
||||
raise
|
||||
|
||||
|
||||
def _handle_batch_pooled(device, batch_tuple):
|
||||
worker_results = []
|
||||
_handle_batch(device, batch_tuple, worker_results)
|
||||
return worker_results
|
||||
|
||||
results: list[thread_result] = []
|
||||
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
|
||||
|
||||
# Submit all GPU work to pool threads
|
||||
pool_devices = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
if thread_pool is not None:
|
||||
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
|
||||
pool_devices.append(device)
|
||||
else:
|
||||
# Fallback: no pool, run everything on main thread
|
||||
_handle_batch(device, batch_tuple, results)
|
||||
|
||||
# Collect results from pool workers
|
||||
for device in pool_devices:
|
||||
worker_results, error = thread_pool.get_result(device)
|
||||
if error is not None:
|
||||
raise error
|
||||
results.extend(worker_results)
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||
if error is not None:
|
||||
raise error
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@@ -649,6 +862,8 @@ def pre_run_control(model, conds):
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
for device_cnet in x['control'].multigpu_clones.values():
|
||||
device_cnet.pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@@ -891,7 +1106,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
cast_transformer_options(to_load_options, device, dtype)
|
||||
|
||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
@@ -900,18 +1117,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
if "patches" in transformer_options:
|
||||
patches = transformer_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
if "patches_replace" in transformer_options:
|
||||
patches = transformer_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
@@ -921,8 +1137,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
if wc_name in transformer_options:
|
||||
wc: dict[str, list] = transformer_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
@@ -930,7 +1146,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
@@ -985,16 +1200,31 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||
|
||||
# Create persistent thread pool for all GPU devices (main + extras)
|
||||
if multigpu_patchers:
|
||||
extra_devices = [p.load_device for p in multigpu_patchers]
|
||||
all_devices = [device] + extra_devices
|
||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||
|
||||
try:
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||
if thread_pool is not None:
|
||||
thread_pool.shutdown()
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
||||
13
comfy/sd.py
13
comfy/sd.py
@@ -62,7 +62,6 @@ import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.ernie
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -1236,7 +1235,6 @@ class TEModel(Enum):
|
||||
QWEN35_4B = 25
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
MINISTRAL_3_3B = 28
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1303,8 +1301,6 @@ def detect_te_model(sd):
|
||||
return TEModel.MISTRAL3_24B
|
||||
else:
|
||||
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
||||
if weight.shape[0] == 3072:
|
||||
return TEModel.MINISTRAL_3_3B
|
||||
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
@@ -1462,10 +1458,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
elif te_model == TEModel.MINISTRAL_3_3B:
|
||||
clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
@@ -1604,10 +1596,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if output_model and out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if output_clip and out[1] is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||
return out
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
|
||||
@@ -26,7 +26,6 @@ import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.ernie
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -1750,37 +1749,6 @@ class RT_DETR_v4(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ernie",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1000.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 10.0
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.ErnieImage(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}ministral3_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from .flux import Mistral3Tokenizer
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
class Ministral3_3BTokenizer(Mistral3Tokenizer):
|
||||
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}):
|
||||
return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data)
|
||||
|
||||
class ErnieTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="ministral3_3b", tokenizer=Mistral3Tokenizer)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
tokens = super().tokenize_with_weights(text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
|
||||
class Ministral3_3BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = {}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ministral3_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class ErnieTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, name="ministral3_3b", clip_model=Ministral3_3BModel):
|
||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class ErnieTEModel_(ErnieTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return ErnieTEModel
|
||||
@@ -116,9 +116,9 @@ class MistralTokenizerClass:
|
||||
return LlamaTokenizerFast(**kwargs)
|
||||
|
||||
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
||||
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"tekken_model": self.tekken_data}
|
||||
|
||||
@@ -60,30 +60,6 @@ class Mistral3Small24BConfig:
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Ministral3_3BConfig:
|
||||
vocab_size: int = 131072
|
||||
hidden_size: int = 3072
|
||||
intermediate_size: int = 9216
|
||||
num_hidden_layers: int = 26
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 262144
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
stop_tokens = [2]
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
vocab_size: int = 151936
|
||||
@@ -970,15 +946,6 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Ministral3_3B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Ministral3_3BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@@ -52,26 +52,6 @@ class TaskImageContent(BaseModel):
|
||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||
|
||||
|
||||
class TaskVideoContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskVideoContent(BaseModel):
|
||||
type: str = Field("video_url")
|
||||
video_url: TaskVideoContentUrl = Field(...)
|
||||
role: str = Field("reference_video")
|
||||
|
||||
|
||||
class TaskAudioContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskAudioContent(BaseModel):
|
||||
type: str = Field("audio_url")
|
||||
audio_url: TaskAudioContentUrl = Field(...)
|
||||
role: str = Field("reference_audio")
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
@@ -84,17 +64,6 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
||||
generate_audio: bool | None = Field(...)
|
||||
|
||||
|
||||
class Seedance2TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1)
|
||||
generate_audio: bool | None = Field(None)
|
||||
resolution: str | None = Field(None)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int | None = Field(None, ge=4, le=15)
|
||||
seed: int | None = Field(None, ge=0, le=2147483647)
|
||||
watermark: bool | None = Field(None)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
@@ -108,27 +77,12 @@ class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusUsage(BaseModel):
|
||||
completion_tokens: int = Field(0)
|
||||
total_tokens: int = Field(0)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
content: TaskStatusResult | None = Field(None)
|
||||
usage: TaskStatusUsage | None = Field(None)
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||
("dreamina-seedance-2-0-260128", True): 0.0043,
|
||||
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
|
||||
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
|
||||
}
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
@@ -158,12 +112,6 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# Seedance 2.0 reference video pixel count limits per model.
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
|
||||
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
|
||||
}
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
|
||||
@@ -8,23 +8,16 @@ from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedance2TaskCreationRequest,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskAudioContent,
|
||||
TaskAudioContentUrl,
|
||||
TaskCreationResponse,
|
||||
TaskImageContent,
|
||||
TaskImageContentUrl,
|
||||
TaskStatusResponse,
|
||||
TaskTextContent,
|
||||
TaskVideoContent,
|
||||
TaskVideoContentUrl,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
)
|
||||
@@ -36,10 +29,7 @@ from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
@@ -56,56 +46,12 @@ SEEDREAM_MODELS = {
|
||||
# Long-running tasks endpoints(e.g., video)
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
SEEDANCE_MODELS = {
|
||||
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||
}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
|
||||
"""Validate reference video pixel count against Seedance 2.0 model limits."""
|
||||
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||
if not limits:
|
||||
return
|
||||
try:
|
||||
w, h = video.get_dimensions()
|
||||
except Exception:
|
||||
return
|
||||
pixels = w * h
|
||||
min_px = limits.get("min")
|
||||
max_px = limits.get("max")
|
||||
if min_px and pixels < min_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
|
||||
)
|
||||
if max_px and pixels > max_px:
|
||||
raise ValueError(
|
||||
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
|
||||
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
if rate is None:
|
||||
return None
|
||||
|
||||
def extractor(response: TaskStatusResponse) -> float | None:
|
||||
if response.usage is None:
|
||||
return None
|
||||
return response.usage.total_tokens * 1.43 * rate / 1_000.0
|
||||
|
||||
return extractor
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||
@@ -389,7 +335,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided."
|
||||
f"Minimum image resolution for the selected model is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
@@ -1005,6 +952,33 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
for i in text_params:
|
||||
if f"--{i} " in prompt:
|
||||
@@ -1066,530 +1040,6 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_text_inputs():
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for video generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||
tooltip="Aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=7,
|
||||
min=4,
|
||||
max=15,
|
||||
step=1,
|
||||
tooltip="Duration of the output video in seconds (4-15).",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=True,
|
||||
tooltip="Enable audio generation for the output video.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2TextToVideoNode",
|
||||
display_name="ByteDance Seedance 2.0 Text to Video",
|
||||
category="api node/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 models based on a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add a watermark to the video.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$m := widgets.model;
|
||||
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=Seedance2TaskCreationRequest(
|
||||
model=model_id,
|
||||
content=[TaskTextContent(text=model["prompt"])],
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2FirstLastFrameNode",
|
||||
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
|
||||
category="api node/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="First frame image for the video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"last_frame",
|
||||
tooltip="Last frame image for the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add a watermark to the video.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$m := widgets.model;
|
||||
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
first_frame: Input.Image,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
last_frame: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||
),
|
||||
role="first_frame",
|
||||
),
|
||||
]
|
||||
if last_frame is not None:
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.")
|
||||
),
|
||||
role="last_frame",
|
||||
),
|
||||
)
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=Seedance2TaskCreationRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def _seedance2_reference_inputs():
|
||||
return [
|
||||
*_seedance2_text_inputs(),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("reference_image"),
|
||||
names=[
|
||||
"image_1",
|
||||
"image_2",
|
||||
"image_3",
|
||||
"image_4",
|
||||
"image_5",
|
||||
"image_6",
|
||||
"image_7",
|
||||
"image_8",
|
||||
"image_9",
|
||||
],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("reference_video"),
|
||||
names=["video_1", "video_2", "video_3"],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_audios",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Audio.Input("reference_audio"),
|
||||
names=["audio_1", "audio_2", "audio_3"],
|
||||
min=0,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2ReferenceNode",
|
||||
display_name="ByteDance Seedance 2.0 Reference to Video",
|
||||
category="api node/video/ByteDance",
|
||||
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
|
||||
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs()),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs()),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add a watermark to the video.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "model.resolution", "model.duration"],
|
||||
input_groups=["model.reference_videos"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$rate480 := 10044;
|
||||
$rate720 := 21600;
|
||||
$m := widgets.model;
|
||||
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
|
||||
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
|
||||
$minVideoFactor := $ceil($dur * 5 / 3);
|
||||
$minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000;
|
||||
$maxVideoCost := (15 + $dur) * $rate * $videoPricePer1K / 1000;
|
||||
$hasVideo
|
||||
? {
|
||||
"type": "range_usd",
|
||||
"min_usd": $minVideoCost,
|
||||
"max_usd": $maxVideoCost,
|
||||
"format": {"approximate": true}
|
||||
}
|
||||
: {
|
||||
"type": "usd",
|
||||
"usd": $noVideoCost,
|
||||
"format": {"approximate": true}
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||
|
||||
reference_images = model.get("reference_images", {})
|
||||
reference_videos = model.get("reference_videos", {})
|
||||
reference_audios = model.get("reference_audios", {})
|
||||
|
||||
if not reference_images and not reference_videos:
|
||||
raise ValueError("At least one reference image or video is required.")
|
||||
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
has_video_input = len(reference_videos) > 0
|
||||
total_video_duration = 0.0
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
video = reference_videos[key]
|
||||
_validate_ref_video_pixels(video, model_id, i)
|
||||
try:
|
||||
dur = video.get_duration()
|
||||
if dur < 1.8:
|
||||
raise ValueError(f"Reference video {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||
total_video_duration += dur
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
if total_video_duration > 15.1:
|
||||
raise ValueError(f"Total reference video duration is {total_video_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||
|
||||
total_audio_duration = 0.0
|
||||
for i, key in enumerate(reference_audios, 1):
|
||||
audio = reference_audios[key]
|
||||
dur = int(audio["waveform"].shape[-1]) / int(audio["sample_rate"])
|
||||
if dur < 1.8:
|
||||
raise ValueError(f"Reference audio {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||
total_audio_duration += dur
|
||||
if total_audio_duration > 15.1:
|
||||
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
]
|
||||
for i, key in enumerate(reference_images, 1):
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(
|
||||
cls,
|
||||
image=reference_images[key],
|
||||
wait_label=f"Uploading image {i}",
|
||||
),
|
||||
),
|
||||
role="reference_image",
|
||||
),
|
||||
)
|
||||
for i, key in enumerate(reference_videos, 1):
|
||||
content.append(
|
||||
TaskVideoContent(
|
||||
video_url=TaskVideoContentUrl(
|
||||
url=await upload_video_to_comfyapi(
|
||||
cls,
|
||||
reference_videos[key],
|
||||
wait_label=f"Uploading video {i}",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
for key in reference_audios:
|
||||
content.append(
|
||||
TaskAudioContent(
|
||||
audio_url=TaskAudioContentUrl(
|
||||
url=await upload_audio_to_comfyapi(
|
||||
cls,
|
||||
reference_audios[key],
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
mime_type="audio/mpeg",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=Seedance2TaskCreationRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
|
||||
poll_interval=9,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
class ByteDanceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -1600,9 +1050,6 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDanceImageToVideoNode,
|
||||
ByteDanceFirstLastFrameNode,
|
||||
ByteDanceImageReferenceNode,
|
||||
ByteDance2TextToVideoNode,
|
||||
ByteDance2FirstLastFrameNode,
|
||||
ByteDance2ReferenceNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -558,7 +558,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$refs := $lookup(inputGroups, "model.reference_images");
|
||||
$refs := inputGroups["model.reference_images"];
|
||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||
{"type":"usd","usd": $price}
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiohttp
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_bytes_to_audio_input,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api_nodes.util._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
)
|
||||
from comfy_api_nodes.util.common_exceptions import ProcessingInterrupted
|
||||
from server import PromptServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SoniloVideoToMusic(IO.ComfyNode):
|
||||
"""Generate music from video using Sonilo's AI model."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="SoniloVideoToMusic",
|
||||
display_name="Sonilo Video to Music",
|
||||
category="api node/audio/Sonilo",
|
||||
description="Generate music from video content using Sonilo's AI model. "
|
||||
"Analyzes the video and creates matching music.",
|
||||
inputs=[
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="Input video to generate music from. Maximum duration: 6 minutes.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Optional text prompt to guide music generation. "
|
||||
"Leave empty for best quality - the model will fully analyze the video content.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
|
||||
"service but kept for graph consistency.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr='{"type":"usd","usd":0.009,"format":{"suffix":"/second"}}',
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
prompt: str = "",
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
video_url = await upload_video_to_comfyapi(cls, video, max_duration=360)
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("video_url", video_url)
|
||||
if prompt.strip():
|
||||
form.add_field("prompt", prompt.strip())
|
||||
audio_bytes = await _stream_sonilo_music(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/sonilo/v2m/generate", method="POST"),
|
||||
form,
|
||||
)
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
|
||||
|
||||
|
||||
class SoniloTextToMusic(IO.ComfyNode):
|
||||
"""Generate music from a text prompt using Sonilo's AI model."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="SoniloTextToMusic",
|
||||
display_name="Sonilo Text to Music",
|
||||
category="api node/audio/Sonilo",
|
||||
description="Generate music from a text prompt using Sonilo's AI model. "
|
||||
"Leave duration at 0 to let the model infer it from the prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Text prompt describing the music to generate.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=0,
|
||||
min=0,
|
||||
max=360,
|
||||
tooltip="Target duration in seconds. Set to 0 to let the model "
|
||||
"infer the duration from the prompt. Maximum: 6 minutes.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
|
||||
"service but kept for graph consistency.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
|
||||
expr="""
|
||||
(
|
||||
widgets.duration > 0
|
||||
? {"type":"usd","usd": 0.005 * widgets.duration}
|
||||
: {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
duration: int = 0,
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("prompt", prompt)
|
||||
if duration > 0:
|
||||
form.add_field("duration", str(duration))
|
||||
audio_bytes = await _stream_sonilo_music(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
|
||||
form,
|
||||
)
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
|
||||
|
||||
|
||||
async def _stream_sonilo_music(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
form: aiohttp.FormData,
|
||||
) -> bytes:
|
||||
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
headers.update(get_auth_header(cls))
|
||||
headers.update(endpoint.headers)
|
||||
|
||||
node_id = get_node_id(cls)
|
||||
start_ts = time.monotonic()
|
||||
last_chunk_status_ts = 0.0
|
||||
audio_streams: dict[int, list[bytes]] = {}
|
||||
title: str | None = None
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=1200.0, sock_read=300.0)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
PromptServer.instance.send_progress_text("Status: Queued", node_id)
|
||||
async with session.post(url, data=form, headers=headers) as resp:
|
||||
if resp.status >= 400:
|
||||
msg = await _extract_error_message(resp)
|
||||
raise Exception(f"Sonilo API error ({resp.status}): {msg}")
|
||||
|
||||
while True:
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
raw_line = await resp.content.readline()
|
||||
if not raw_line:
|
||||
break
|
||||
|
||||
line = raw_line.decode("utf-8").strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
evt = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Sonilo: skipping malformed NDJSON line")
|
||||
continue
|
||||
|
||||
evt_type = evt.get("type")
|
||||
if evt_type == "error":
|
||||
code = evt.get("code", "UNKNOWN")
|
||||
message = evt.get("message", "Unknown error")
|
||||
raise Exception(f"Sonilo generation error ({code}): {message}")
|
||||
if evt_type == "duration":
|
||||
duration_sec = evt.get("duration_sec")
|
||||
if duration_sec is not None:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Status: Generating\nVideo duration: {duration_sec:.1f}s",
|
||||
node_id,
|
||||
)
|
||||
elif evt_type in ("titles", "title"):
|
||||
# v2m sends a "titles" list, t2m sends a scalar "title"
|
||||
if evt_type == "titles":
|
||||
titles = evt.get("titles", [])
|
||||
if titles:
|
||||
title = titles[0]
|
||||
else:
|
||||
title = evt.get("title") or title
|
||||
if title:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Status: Generating\nTitle: {title}",
|
||||
node_id,
|
||||
)
|
||||
elif evt_type == "audio_chunk":
|
||||
stream_idx = evt.get("stream_index", 0)
|
||||
chunk_data = base64.b64decode(evt["data"])
|
||||
|
||||
if stream_idx not in audio_streams:
|
||||
audio_streams[stream_idx] = []
|
||||
audio_streams[stream_idx].append(chunk_data)
|
||||
|
||||
now = time.monotonic()
|
||||
if now - last_chunk_status_ts >= 1.0:
|
||||
total_chunks = sum(len(chunks) for chunks in audio_streams.values())
|
||||
elapsed = int(now - start_ts)
|
||||
status_lines = ["Status: Receiving audio"]
|
||||
if title:
|
||||
status_lines.append(f"Title: {title}")
|
||||
status_lines.append(f"Chunks received: {total_chunks}")
|
||||
status_lines.append(f"Time elapsed: {elapsed}s")
|
||||
PromptServer.instance.send_progress_text("\n".join(status_lines), node_id)
|
||||
last_chunk_status_ts = now
|
||||
elif evt_type == "complete":
|
||||
break
|
||||
|
||||
if not audio_streams:
|
||||
raise Exception("Sonilo API returned no audio data.")
|
||||
|
||||
PromptServer.instance.send_progress_text("Status: Completed", node_id)
|
||||
selected_stream = 0 if 0 in audio_streams else min(audio_streams)
|
||||
return b"".join(audio_streams[selected_stream])
|
||||
|
||||
|
||||
async def _extract_error_message(resp: aiohttp.ClientResponse) -> str:
|
||||
"""Extract a human-readable error message from an HTTP error response."""
|
||||
try:
|
||||
error_body = await resp.json()
|
||||
detail = error_body.get("detail", {})
|
||||
if isinstance(detail, dict):
|
||||
return detail.get("message", str(detail))
|
||||
return str(detail)
|
||||
except Exception:
|
||||
return await resp.text()
|
||||
|
||||
|
||||
class SoniloExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [SoniloVideoToMusic, SoniloTextToMusic]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SoniloExtension:
|
||||
return SoniloExtension()
|
||||
89
comfy_extras/nodes_multigpu.py
Normal file
89
comfy_extras/nodes_multigpu.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import cleandoc
|
||||
from typing import TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.multigpu
|
||||
|
||||
|
||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
"""
|
||||
Prepares model to have sampling accelerated via splitting work units.
|
||||
|
||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||
|
||||
Other than those exceptions, this node can be placed in any order.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_WorkUnits",
|
||||
display_name="MultiGPU CFG Split",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class MultiGPUOptionsNode(io.ComfyNode):
|
||||
"""
|
||||
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_Options",
|
||||
display_name="MultiGPU Options",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Int.Input("device_index", default=0, min=0, max=64),
|
||||
io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
|
||||
io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Custom("GPU_OPTIONS").Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
|
||||
if not gpu_options:
|
||||
gpu_options = comfy.multigpu.GPUOptionsGroup()
|
||||
gpu_options.clone()
|
||||
|
||||
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
|
||||
gpu_options.add(opt)
|
||||
|
||||
return io.NodeOutput(gpu_options)
|
||||
|
||||
|
||||
class MultiGPUExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
MultiGPUCFGSplitNode,
|
||||
# MultiGPUOptionsNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MultiGPUExtension:
|
||||
return MultiGPUExtension()
|
||||
@@ -11,7 +11,7 @@ class PreviewAny():
|
||||
"required": {"source": (IO.ANY, {})},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "main"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@@ -33,7 +33,7 @@ class PreviewAny():
|
||||
except Exception:
|
||||
value = 'source exists, but could not be serialized.'
|
||||
|
||||
return {"ui": {"text": (value,)}, "result": (value,)}
|
||||
return {"ui": {"text": (value,)}}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PreviewAny": PreviewAny,
|
||||
|
||||
@@ -32,12 +32,10 @@ class RTDETR_detect(io.ComfyNode):
|
||||
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
||||
B, H, W, C = image.shape
|
||||
|
||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
results = []
|
||||
for i in range(0, B, 32):
|
||||
batch = image[i:i + 32]
|
||||
image_in = comfy.utils.common_upscale(batch.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
results.extend(model.model.diffusion_model(image_in, (W, H)))
|
||||
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
|
||||
|
||||
all_bbox_dicts = []
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
import math
|
||||
import colorsys
|
||||
@@ -411,9 +410,7 @@ class SDPoseDrawKeypoints(io.ComfyNode):
|
||||
pose_outputs.append(canvas)
|
||||
|
||||
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
|
||||
final_pose_output = torch.from_numpy(pose_outputs_np).to(
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
dtype=comfy.model_management.intermediate_dtype()) / 255.0
|
||||
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
|
||||
return io.NodeOutput(final_pose_output)
|
||||
|
||||
class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
@@ -462,27 +459,6 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
|
||||
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
|
||||
|
||||
def _resize_to_model(imgs):
|
||||
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
|
||||
h, w = imgs.shape[-3], imgs.shape[-2]
|
||||
scale = min(model_h / h, model_w / w)
|
||||
sh, sw = int(round(h * scale)), int(round(w * scale))
|
||||
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
|
||||
chw = imgs.permute(0, 3, 1, 2).float()
|
||||
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
|
||||
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
|
||||
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
|
||||
return padded.permute(0, 2, 3, 1), scale, pt, pl
|
||||
|
||||
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
|
||||
"""Remap keypoints from model space back to original image space."""
|
||||
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
||||
invalid = kp[..., 0] < 0
|
||||
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
|
||||
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
|
||||
kp[invalid] = -1
|
||||
return kp
|
||||
|
||||
def _run_on_latent(latent_batch):
|
||||
"""Run one forward pass and return (keypoints_list, scores_list) for the batch."""
|
||||
nonlocal captured_feat
|
||||
@@ -528,19 +504,36 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
continue
|
||||
|
||||
crop_h_px, crop_w_px = y2 - y1, x2 - x1
|
||||
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
|
||||
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
|
||||
|
||||
# scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
|
||||
scale = min(model_h / crop_h_px, model_w / crop_w_px)
|
||||
scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
|
||||
pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
|
||||
|
||||
crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
||||
padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
|
||||
padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
||||
crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC
|
||||
|
||||
latent_crop = vae.encode(crop_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_crop)
|
||||
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
|
||||
kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space
|
||||
|
||||
# remove padding offset, undo scale, offset to full-image coordinates.
|
||||
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
||||
kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
|
||||
kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
|
||||
|
||||
img_keypoints.append(kp)
|
||||
img_scores.append(sc_batch[0])
|
||||
img_scores.append(sc)
|
||||
else:
|
||||
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
|
||||
latent_img = vae.encode(img_resized)
|
||||
# No bboxes for this image – run on the full image
|
||||
latent_img = vae.encode(img)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_img)
|
||||
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
|
||||
img_keypoints.append(kp_batch[0])
|
||||
img_scores.append(sc_batch[0])
|
||||
|
||||
all_keypoints.append(img_keypoints)
|
||||
@@ -548,16 +541,19 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
pbar.update(1)
|
||||
|
||||
else: # full-image mode, batched
|
||||
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
|
||||
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
|
||||
latent_batch = vae.encode(batch_resized)
|
||||
tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
|
||||
for batch_start in range(0, total_images, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_images)
|
||||
latent_batch = vae.encode(image[batch_start:batch_end])
|
||||
|
||||
kp_batch, sc_batch = _run_on_latent(latent_batch)
|
||||
|
||||
for kp, sc in zip(kp_batch, sc_batch):
|
||||
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
|
||||
all_keypoints.append([kp])
|
||||
all_scores.append([sc])
|
||||
tqdm_pbar.update(1)
|
||||
|
||||
pbar.update(len(kp_batch))
|
||||
pbar.update(batch_end - batch_start)
|
||||
|
||||
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
|
||||
return io.NodeOutput(openpose_frames)
|
||||
|
||||
@@ -6,7 +6,6 @@ import comfy.utils
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.model_management
|
||||
|
||||
try:
|
||||
from spandrel_extra_arches import EXTRA_REGISTRY
|
||||
@@ -79,15 +78,13 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
tile = 512
|
||||
overlap = 32
|
||||
|
||||
output_device = comfy.model_management.intermediate_device()
|
||||
|
||||
oom = True
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
@@ -97,7 +94,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||
return io.NodeOutput(s)
|
||||
|
||||
upscale = execute # TODO: remove
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.19.1"
|
||||
__version__ = "0.18.1"
|
||||
|
||||
1
nodes.py
1
nodes.py
@@ -2412,6 +2412,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_lt_audio.py",
|
||||
"nodes_lt.py",
|
||||
"nodes_hooks.py",
|
||||
"nodes_multigpu.py",
|
||||
"nodes_load_3d.py",
|
||||
"nodes_cosmos.py",
|
||||
"nodes_video.py",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.19.1"
|
||||
version = "0.18.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.11
|
||||
comfyui-workflow-templates==0.9.54
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.44
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
Reference in New Issue
Block a user