Compare commits

..

21 Commits

Author SHA1 Message Date
Austin Mroz
3fbf3b421f WIP branch node 2026-03-27 14:35:24 -07:00
comfyanonymous
a11f68dd3b Fix canny node not working with fp16. (#13085) 2026-03-20 23:15:50 -04:00
comfyanonymous
dc719cde9c ComfyUI version 0.18.0 2026-03-20 20:09:15 -04:00
Jedrzej Kosinski
87cda1fc25 Move inline comfy.context_windows imports to top-level in model_base.py (#13083)
The recent PR that added resize_cond_for_context_window methods to
model classes used inline 'import comfy.context_windows' in each
method body. This moves that import to the top-level import section,
replacing 4 duplicate inline imports with a single top-level one.
2026-03-20 20:03:42 -04:00
comfyanonymous
45d5c83a30 Make EmptyImage node follow intermediate device/dtype. (#13079) 2026-03-20 16:08:26 -04:00
Alexander Piskun
c646d211be feat(api-nodes): add Quiver SVG nodes (#13047) 2026-03-20 12:23:16 -07:00
drozbay
589228e671 Add slice_cond and per-model context window cond resizing (#12645)
* Add slice_cond and per-model context window cond resizing

* Fix cond_value.size() call in context window cond resizing

* Expose additional advanced inputs for ContextWindowsManualNode

Necessary for WanAnimate context windows workflow, which needs cond_retain_index_list = 0 to work properly with its reference input.

---------
2026-03-19 20:42:42 -07:00
Alexander Piskun
e4455fd43a [API Nodes] mark seedream-3-0-t2i and seedance-1-0-lite models as deprecated (#13060)
* chore(api-nodes): mark seedream-3-0-t2i and seedance-1-0-lite models as deprecated

* fix(api-nodes): fixed old regression in the ByteDanceImageReference node

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-19 20:05:01 -07:00
rattus
f49856af57 ltx: vae: Fix missing init variable (#13074)
Forgot to push this ammendment. Previous test results apply to this.
2026-03-19 22:34:58 -04:00
rattus
82b868a45a Fix VRAM leak in tiler fallback in video VAEs (#13073)
* sd: soft_empty_cache on tiler fallback

This doesnt cost a lot and creates the expected VRAM reduction in
resource monitors when you fallback to tiler.

* wan: vae: Don't recursion in local fns (move run_up)

Moved Decoder3d’s recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.

* ltx: vae: Don't recursion in local fns (move run_up)

Mov the recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.
2026-03-19 22:30:27 -04:00
comfyanonymous
8458ae2686 Revert "fix: run text encoders on MPS GPU instead of CPU for Apple Silicon (#…" (#13070)
This reverts commit b941913f1d.
2026-03-19 15:27:55 -04:00
Jukka Seppänen
fd0261d2bc Reduce tiled decode peak memory (#13050) 2026-03-19 13:29:34 -04:00
rattus
ab14541ef7 memory: Add more exclusion criteria to pinned read (#13067) 2026-03-19 10:03:20 -07:00
rattus
6589562ae3 ltx: vae: implement chunked encoder + CPU IO chunking (Big VRAM reductions) (#13062)
* ltx: vae: add cache state to downsample block

* ltx: vae: Add time stride awareness to causal_conv_3d

* ltx: vae: Automate truncation for encoder

Other VAEs just truncate without error. Do the same.

* sd/ltx: Make chunked_io a flag in its own right

Taking this bi-direcitonal, so make it a for-purpose named flag.

* ltx: vae: implement chunked encoder + CPU IO chunking

People are doing things with big frame counts in LTX including V2V
flows. Implement the time-chunked encoder to keep the VRAM down, with
the converse of the new CPU pre-allocation technique, where the chunks
are brought from the CPU JIT.

* ltx: vae-encode: round chunk sizes more strictly

Only powers of 2 and multiple of 8 are valid due to cache slicing.
2026-03-19 10:01:12 -07:00
rattus
fabed694a2 ltx: vae: implement chunked encoder + CPU IO chunking (Big VRAM reductions) (#13062)
* ltx: vae: add cache state to downsample block

* ltx: vae: Add time stride awareness to causal_conv_3d

* ltx: vae: Automate truncation for encoder

Other VAEs just truncate without error. Do the same.

* sd/ltx: Make chunked_io a flag in its own right

Taking this bi-direcitonal, so make it a for-purpose named flag.

* ltx: vae: implement chunked encoder + CPU IO chunking

People are doing things with big frame counts in LTX including V2V
flows. Implement the time-chunked encoder to keep the VRAM down, with
the converse of the new CPU pre-allocation technique, where the chunks
are brought from the CPU JIT.

* ltx: vae-encode: round chunk sizes more strictly

Only powers of 2 and multiple of 8 are valid due to cache slicing.
2026-03-19 09:58:47 -07:00
comfyanonymous
f6b869d7d3 fp16 intermediates doen't work for some text enc models. (#13056) 2026-03-18 19:42:28 -04:00
comfyanonymous
56ff88f951 Fix regression. (#13053) 2026-03-18 18:35:25 -04:00
Jukka Seppänen
9fff091f35 Further Reduce LTX VAE decode peak RAM usage (#13052) 2026-03-18 18:32:26 -04:00
comfyanonymous
dcd659590f Make more intermediate values follow the intermediate dtype. (#13051) 2026-03-18 18:14:18 -04:00
Alexander Brown
b67ed2a45f Update comfyui-frontend-package version to 1.41.21 (#13035) 2026-03-18 16:36:39 -04:00
Alexander Piskun
06957022d4 fix(api-nodes): add support for "thought_image" in Nano Banana 2 and corrected price badges (#13038) 2026-03-18 10:21:58 -07:00
32 changed files with 765 additions and 927 deletions

View File

@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")

View File

@@ -93,6 +93,50 @@ class IndexListCallbacks:
return {}
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
return None
cond_tensor = cond_value.cond
if temporal_dim >= cond_tensor.ndim:
return None
cond_size = cond_tensor.size(temporal_dim)
if temporal_scale == 1:
expected_size = x_in.size(window.dim) - temporal_offset
if cond_size != expected_size:
return None
if temporal_offset == 0 and temporal_scale == 1:
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
return cond_value._copy_with(sliced)
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0:
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
indices = [i for i in indices if 0 <= i]
else:
indices = list(window.index_list)
if not indices:
return None
if temporal_scale > 1:
scaled = []
for i in indices:
for k in range(temporal_scale):
si = i * temporal_scale + k
if si < cond_size:
scaled.append(si)
indices = scaled
if not indices:
return None
idx = tuple([slice(None)] * temporal_dim + [indices])
sliced = cond_tensor[idx].to(device)
return cond_value._copy_with(sliced)
@dataclass
class ContextSchedule:
name: str
@@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item[cond_key] = result
handled = True
break
if not handled and self._model is not None:
result = self._model.resize_cond_for_context_window(
cond_key, cond_value, window, x_in, device,
retain_index_list=self.cond_retain_index_list)
if result is not None:
new_cond_item[cond_key] = result
handled = True
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
@@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))

View File

@@ -15,14 +15,13 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -39,7 +38,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@@ -65,18 +64,6 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@@ -90,7 +77,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
self.previous_controlnet: Union[ControlBase, None] = None
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@@ -98,7 +85,6 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -125,38 +111,17 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@@ -165,7 +130,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c: ControlBase):
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@@ -319,14 +284,6 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@@ -949,14 +906,6 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(stride, int):
self.time_stride = stride
else:
self.time_stride = stride[0]
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
@@ -58,18 +63,23 @@ class CausalConv3d(nn.Module):
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
input_length = sum([piece.shape[2] for piece in pieces])
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
if needs_caching and cache_length == 0:
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
if needs_caching and x.shape[2] >= cache_length:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)

View File

@@ -233,10 +233,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
sample = self.conv_in(sample)
checkpoint_fn = (
@@ -247,10 +244,14 @@ class Encoder(nn.Module):
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
if sample is None or sample.shape[2] == 0:
return None
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if sample is None or sample.shape[2] == 0:
return None
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
@@ -282,9 +283,35 @@ class Encoder(nn.Module):
return sample
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
outputs = []
samples = [sample[:, :, :1, :, :]]
if sample.shape[2] > 1:
chunk_t = max(2, max_chunk_size // frame_size)
if chunk_t < 4:
chunk_t = 2
elif chunk_t < 8:
chunk_t = 4
else:
chunk_t = (chunk_t // 8) * 8
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
for chunk_idx, chunk in enumerate(samples):
if chunk_idx == len(samples) - 1:
mark_conv3d_ended(self)
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
output = self._forward_chunk(chunk)
if output is not None:
outputs.append(output)
return torch_cat_if_needed(outputs, dim=2)
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
@@ -473,6 +500,17 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
ts, hs, ws, to = 1, 1, 1, 0
for block in self.up_blocks:
if isinstance(block, DepthToSpaceUpsample):
ts *= block.stride[0]
hs *= block.stride[1]
ws *= block.stride[2]
if block.stride[0] > 1:
to = to * block.stride[0] + 1
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
@@ -494,11 +532,62 @@ class Decoder(nn.Module):
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
return
up_block = self.up_blocks[idx]
if ended:
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
@@ -513,6 +602,7 @@ class Decoder(nn.Module):
)
timestep_shift_scale = None
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
@@ -540,59 +630,18 @@ class Decoder(nn.Module):
)
timestep_shift_scale = ada_values.unbind(dim=1)
output = []
if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample_ref, ended):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample.to(comfy.model_management.intermediate_device()))
return
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
run_up(idx + 1, next_sample_ref, ended)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
run_up(0, [sample], True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
return output_buffer
def forward(self, *args, **kwargs):
try:
@@ -716,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module):
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
tid = threading.get_ident()
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
if cached_input is not None:
x = torch_cat_if_needed([cached_input, x], dim=2)
cached_input = None
if self.stride[0] == 2 and pad_first:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
pad_first = False
if x.shape[2] < self.stride[0]:
cached_input = x
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return None
# skip connection
x_in = rearrange(
@@ -736,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module):
# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and x.shape[2] == 1:
if cached_x is not None:
x = torch_cat_if_needed([cached_x, x], dim=2)
cached_x = None
else:
cached_x = x
x = None
x = x + x_in
if x is not None:
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
cached = add_exchange_cache(x, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return x
@@ -1077,6 +1150,8 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
comfy_has_chunked_io = True
def __init__(self, version=0, config=None):
super().__init__()
@@ -1219,14 +1294,15 @@ class VideoVAE(nn.Module):
}
return config
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
def encode(self, x, device=None):
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x):
def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)
def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)

View File

@@ -360,6 +360,43 @@ class Decoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]):
self.run_up(
layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_cache,
feat_idx.copy(),
out_chunks,
)
del x
return
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
next_x_ref = [x]
del x
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
@@ -380,42 +417,7 @@ class Decoder3d(nn.Module):
out_chunks = []
def run_up(layer_idx, x_ref, feat_idx):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
for frame_idx in range(x.shape[2]):
run_up(
layer_idx,
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_idx.copy(),
)
del x
return
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
next_x_ref = [x]
del x
run_up(layer_idx + 1, next_x_ref, feat_idx)
run_up(0, [x], feat_idx)
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
return out_chunks

View File

@@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination):
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
return False
if info.size == 0:

View File

@@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@@ -285,6 +286,12 @@ class BaseModel(torch.nn.Module):
return data
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
"""Override in subclasses to handle model-specific cond slicing for context windows.
Return a sliced cond object, or None to fall through to default handling.
Use comfy.context_windows.slice_cond() for common cases."""
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
@@ -1375,6 +1382,11 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "vace_context":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@@ -1427,6 +1439,11 @@ class WAN21_HuMo(WAN21):
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@@ -1444,6 +1461,13 @@ class WAN22_Animate(WAN21):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "face_pixel_values":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
if cond_key == "pose_latents":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@@ -1480,6 +1504,11 @@ class WAN22_S2V(WAN21):
out['reference_motion'] = reference_motion.shape
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@@ -15,7 +15,6 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import psutil
import logging
@@ -33,11 +32,6 @@ import comfy.memory_management
import comfy.utils
import comfy.quant_ops
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -211,25 +205,6 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
return devices
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@@ -518,13 +493,9 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models: list[LoadedModel] = []
current_loaded_models = []
def module_size(module):
module_mem = 0
@@ -557,7 +528,7 @@ def module_mmap_residency(module, free=False):
return mmap_touched_mem, module_mem
class LoadedModel:
def __init__(self, model: ModelPatcher):
def __init__(self, model):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@@ -565,7 +536,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model: ModelPatcher):
def _set_model(self, model):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@@ -1032,7 +1003,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:
@@ -1809,34 +1780,7 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
for device in get_all_torch_devices():
free_memory(1e30, device)
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():

View File

@@ -23,7 +23,6 @@ import inspect
import logging
import math
import uuid
import copy
from typing import Callable, Optional
import torch
@@ -76,15 +75,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
def create_hook_patches_clone(orig_hook_patches):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@@ -276,10 +272,7 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
self.cached_patcher_init: tuple[Callable, tuple] | None = None
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -336,8 +329,6 @@ class ModelPatcher:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
@@ -396,98 +387,19 @@ class ModelPatcher:
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
comfy.model_management.unload_model_and_clones(self)
n = self.clone()
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
if self.cached_patcher_init is not None:
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
n.model = temp_model_patcher.model
else:
n.model = copy.deepcopy(n.model)
# unlike for normal clone, backup dicts that shared same ref should not;
# otherwise, patchers that have deep copies of base models will erroneously influence each other.
n.backup = copy.deepcopy(n.backup)
n.object_patches_backup = copy.deepcopy(n.object_patches_backup)
n.hook_backup = copy.deepcopy(n.hook_backup)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone):
return False
def clone_has_same_weights(self, clone: 'ModelPatcher'):
if not self.is_clone(clone):
return False
if self.current_hooks != clone.current_hooks:
return False
@@ -1258,7 +1170,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
all_models: list[ModelPatcher] = []
all_models = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@@ -1312,13 +1224,9 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
def prepare_state(self, timestep, model_options, ignore_multigpu=False):
def prepare_state(self, timestep):
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep, model_options, ignore_multigpu)
if not ignore_multigpu and "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options, ignore_multigpu=True)
callback(self, timestep)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@@ -1331,18 +1239,12 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@@ -1354,28 +1256,6 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):

View File

@@ -1,167 +0,0 @@
from __future__ import annotations
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# check if there are any ModelPatchers currently loaded that could be referenced here after a clone
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device:
device_patcher = lm.clone()
logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# TODO: only keep model clones that don't go 'past' the intended max_gpu count
# multigpu_models = model.get_additional_models_with_key("multigpu")
# new_multigpu_models = []
# for m in multigpu_models:
# if m.load_device in limit_extra_devices:
# new_multigpu_models.append(m)
# model.set_additional_models("multigpu", new_multigpu_models)
# persist skip_devices for use in sampling code
# if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options:
# model.model_options["multigpu_skip_devices"] = skip_devices
return model
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored

View File

@@ -3,8 +3,6 @@ from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"

View File

@@ -20,7 +20,7 @@ try:
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("cuda") # multigpu will not work rn with comfy-kitchen on cuda backend
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")

View File

@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples

View File

@@ -1,17 +1,16 @@
from __future__ import annotations
import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@@ -119,47 +118,6 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@@ -184,8 +142,7 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@@ -197,7 +154,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model: BaseModel = model.model
real_model = model.model
return real_model, conds, models
@@ -243,18 +200,3 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers

View File

@@ -1,9 +1,7 @@
from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
from typing import TYPE_CHECKING, Callable, NamedTuple
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@@ -21,7 +19,6 @@ import comfy.context_windows
import comfy.utils
import scipy.stats
import numpy
import threading
def add_area_dims(area, num_dims):
@@ -144,7 +141,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list, device=None):
def cond_cat(c_list):
temp = {}
for x in c_list:
for k in x:
@@ -156,8 +153,6 @@ def cond_cat(c_list, device=None):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out
@@ -217,9 +212,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -251,7 +244,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
model.current_patcher.prepare_state(timestep)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -352,196 +345,6 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
out_conds = []
out_counts = []
# separate conds by matching hooks
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
default_conds = []
has_default_conds = False
output_device = x_in.device
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep, model_options)
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = comfy.model_management.get_free_memory(current_device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)
batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):
output: Any
mult: Any
area: Any
batch_chunks: int
cond_or_uncond: Any
error: Exception = None
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():
for hooks, to_batch in batch_tuple:
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control: ControlBase = None
patches = None
for x in to_batch:
o = x
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x).to(device)
c = cond_cat(c, device=device)
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
transformer_options.get("patches", {}),
patches
)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep.to(device)
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
transformer_options["multigpu_thread_device"] = device
cast_transformer_options(transformer_options, device=device)
c['transformer_options'] = transformer_options
if control is not None:
device_control = control.get_instance_for_device(device)
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e))
raise
results: list[thread_result] = []
threads: list[threading.Thread] = []
for device, batch_tuple in device_batched_hooked_to_run.items():
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
threads.append(new_thread)
new_thread.start()
for thread in threads:
thread.join()
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
raise error
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -846,8 +649,6 @@ def pre_run_control(model, conds):
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
for device_cnet in x['control'].multigpu_clones.values():
device_cnet.pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -1090,9 +891,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -1101,17 +900,18 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
# if nothing to apply, do nothing
if len(casts) == 0:
return
# try to call .to on patches
if "patches" in transformer_options:
patches = transformer_options["patches"]
if "patches" in to_load_options:
patches = to_load_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in transformer_options:
patches = transformer_options["patches_replace"]
if "patches_replace" in to_load_options:
patches = to_load_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -1121,8 +921,8 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
if wc_name in transformer_options:
wc: dict[str, list] = transformer_options[wc_name]
if wc_name in to_load_options:
wc: dict[str, list] = to_load_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -1130,6 +930,7 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -1184,8 +985,6 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
@@ -1193,13 +992,9 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model

View File

@@ -951,12 +951,23 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
@@ -967,6 +978,7 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -1027,8 +1039,13 @@ class VAE:
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
@@ -1043,6 +1060,7 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@@ -1554,7 +1572,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):

View File

@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2]
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
else:
first_pooled = pooled
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
extra[k] = v
r = r + (extra,)

View File

@@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out = output[b:b+1].zero_()
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
@@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
@@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if pbar is not None:
pbar.update(1)
output[b:b+1] = out/out_div
out.div_(out_div)
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):

View File

@@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
thought: bool | None = Field(None)
class GeminiTextPart(BaseModel):

View File

@@ -0,0 +1,43 @@
from pydantic import BaseModel, Field
class QuiverImageObject(BaseModel):
url: str = Field(...)
class QuiverTextToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
prompt: str = Field(...)
instructions: str | None = Field(default=None)
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverImageToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
image: QuiverImageObject = Field(...)
auto_crop: bool | None = Field(default=None)
target_size: int | None = Field(default=None, ge=128, le=4096)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverSVGResponseItem(BaseModel):
svg: str = Field(...)
mime_type: str | None = Field(default="image/svg+xml")
class QuiverSVGUsage(BaseModel):
total_tokens: int | None = Field(default=None)
input_tokens: int | None = Field(default=None)
output_tokens: int | None = Field(default=None)
class QuiverSVGResponse(BaseModel):
id: str | None = Field(default=None)
created: int | None = Field(default=None)
data: list[QuiverSVGResponseItem] = Field(...)
usage: QuiverSVGUsage | None = Field(default=None)

View File

@@ -47,6 +47,10 @@ SEEDREAM_MODELS = {
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
logger = logging.getLogger(__name__)
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
@@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode):
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
is_deprecated=True,
)
@classmethod
@@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
]
return await process_video_task(
cls,
payload=Image2VideoTaskCreationRequest(model=model, content=x),
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -952,6 +957,12 @@ async def process_video_task(
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
if payload.model in DEPRECATED_MODELS:
logger.warning(
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
payload.model,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),

View File

@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
$m := widgets.model;
$r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2");
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
@@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/*")
for part in parts:
if (part.thought is True) != thought:
continue
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
@@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
outputs=[
IO.Image.Output(),
IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
class GeminiExtension(ComfyExtension):

View File

@@ -0,0 +1,291 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.quiver import (
QuiverImageObject,
QuiverImageToSVGRequest,
QuiverSVGResponse,
QuiverTextToSVGRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
sync_op,
upload_image_to_comfyapi,
validate_string,
)
from comfy_extras.nodes_images import SVG
class QuiverTextToSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="QuiverTextToSVGNode",
display_name="Quiver Text to SVG",
category="api node/image/Quiver",
description="Generate an SVG from a text prompt using Quiver AI.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired SVG output.",
),
IO.String.Input(
"instructions",
multiline=True,
default="",
tooltip="Additional style or formatting guidance.",
optional=True,
),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="ref_",
min=0,
max=4,
),
tooltip="Up to 4 reference images to guide the generation.",
optional=True,
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
tooltip="Model to use for SVG generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
instructions: str = None,
reference_images: IO.Autogrow.Type = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
references = None
if reference_images:
references = []
for key in reference_images:
url = await upload_image_to_comfyapi(cls, reference_images[key])
references.append(QuiverImageObject(url=url))
if len(references) > 4:
raise ValueError("Maximum 4 reference images are allowed.")
instructions_val = instructions.strip() if instructions else None
if instructions_val == "":
instructions_val = None
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
response_model=QuiverSVGResponse,
data=QuiverTextToSVGRequest(
model=model["model"],
prompt=prompt,
instructions=instructions_val,
references=references,
temperature=model.get("temperature"),
top_p=model.get("top_p"),
presence_penalty=model.get("presence_penalty"),
),
)
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
return IO.NodeOutput(SVG(svg_data))
class QuiverImageToSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="QuiverImageToSVGNode",
display_name="Quiver Image to SVG",
category="api node/image/Quiver",
description="Vectorize a raster image into SVG using Quiver AI.",
inputs=[
IO.Image.Input(
"image",
tooltip="Input image to vectorize.",
),
IO.Boolean.Input(
"auto_crop",
default=False,
tooltip="Automatically crop to the dominant subject.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Int.Input(
"target_size",
default=1024,
min=128,
max=4096,
tooltip="Square resize target in pixels.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
tooltip="Model to use for SVG vectorization.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
),
)
@classmethod
async def execute(
cls,
image,
auto_crop: bool,
model: dict,
seed: int,
) -> IO.NodeOutput:
image_url = await upload_image_to_comfyapi(cls, image)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
response_model=QuiverSVGResponse,
data=QuiverImageToSVGRequest(
model=model["model"],
image=QuiverImageObject(url=image_url),
auto_crop=auto_crop if auto_crop else None,
target_size=model.get("target_size"),
temperature=model.get("temperature"),
top_p=model.get("top_p"),
presence_penalty=model.get("presence_penalty"),
),
)
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
return IO.NodeOutput(SVG(svg_data))
class QuiverExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
QuiverTextToSVGNode,
QuiverImageToSVGNode,
]
async def comfy_entrypoint() -> QuiverExtension:
return QuiverExtension()

View File

@@ -3,6 +3,7 @@ from typing_extensions import override
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
import torch
class Canny(io.ComfyNode):
@@ -29,8 +30,8 @@ class Canny(io.ComfyNode):
@classmethod
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
return io.NodeOutput(img_out)

View File

@@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),

View File

@@ -8,6 +8,36 @@ from comfy_api.latest import _io
MISSING = object()
class BranchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
mtemplate = io.MatchType.Template("switch")
minput = io.MatchType.Input("branch", template=mtemplate, lazy=True, optional=True)
template = _io.Autogrow.TemplatePrefix(input=minput, prefix="branch", min=1, max=10)
return io.Schema(
node_id="BranchNode",
display_name="Branch",
category="logic",
is_experimental=True,
inputs=[
io.Int.Input("branch"),
_io.Autogrow.Input("autogrow", template=template)
],
outputs=[
io.MatchType.Output(template=mtemplate, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, branch, autogrow):
print('lazy', branch)
return ['autogrow.' + list(autogrow.keys())[branch]]
@classmethod
def execute(cls, branch, autogrow) -> io.NodeOutput:
print(branch)
return list(autogrow.values())[branch],
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -268,6 +298,7 @@ class LogicExtension(ComfyExtension):
# AutogrowPrefixTestNode,
# ComboOutputTestNode,
# InvertBooleanNode,
BranchNode,
]
async def comfy_entrypoint() -> LogicExtension:

View File

@@ -1,86 +0,0 @@
from __future__ import annotations
from inspect import cleandoc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.multigpu
class MultiGPUWorkUnitsNode:
"""
Prepares model to have sampling accelerated via splitting work units.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
"""
NodeId = "MultiGPU_WorkUnits"
NodeName = "MultiGPU Work Units"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}),
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "init_multigpu"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
return (model,)
class MultiGPUOptionsNode:
"""
Select the relative speed of GPUs in the special case they have significantly different performance from one another.
"""
NodeId = "MultiGPU_Options"
NodeName = "MultiGPU Options"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"device_index": ("INT", {"default": 0, "min": 0, "max": 64}),
"relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01})
},
"optional": {
"gpu_options": ("GPU_OPTIONS",)
}
}
RETURN_TYPES = ("GPU_OPTIONS",)
FUNCTION = "create_gpu_options"
CATEGORY = "advanced/multigpu"
DESCRIPTION = cleandoc(__doc__)
def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None):
if not gpu_options:
gpu_options = comfy.multigpu.GPUOptionsGroup()
gpu_options.clone()
opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
gpu_options.add(opt)
return (gpu_options,)
node_list = [
MultiGPUWorkUnitsNode,
MultiGPUOptionsNode
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
for node in node_list:
NODE_CLASS_MAPPINGS[node.NodeId] = node
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.17.0"
__version__ = "0.18.0"

View File

@@ -1966,9 +1966,11 @@ class EmptyImage:
CATEGORY = "image"
def generate(self, width, height, batch_size=1, color=0):
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
dtype = comfy.model_management.intermediate_dtype()
device = comfy.model_management.intermediate_device()
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype)
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype)
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype)
return (torch.cat((r, g, b), dim=-1), )
class ImagePadForOutpaint:
@@ -2410,7 +2412,6 @@ async def init_builtin_extra_nodes():
"nodes_lt_audio.py",
"nodes_lt.py",
"nodes_hooks.py",
"nodes_multigpu.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.17.0"
version = "0.18.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.41.20
comfyui-frontend-package==1.41.21
comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3
torch