Compare commits

..

40 Commits

Author SHA1 Message Date
Dante
5661041028 Merge branch 'master' into feature/number-convert-node 2026-03-24 07:23:38 +09:00
Jukka Seppänen
e87858e974 feat: LTX2: Support reference audio (ID-LoRA) (#13111) 2026-03-23 18:22:24 -04:00
Dr.Lt.Data
da6edb5a4e bump manager version to 4.1b8 (#13108) 2026-03-23 12:59:21 -04:00
comfyanonymous
6265a239f3 Add warning for users who disable dynamic vram. (#13113) 2026-03-22 18:46:18 -04:00
Talmaj
d49420b3c7 LongCat-Image edit (#13003) 2026-03-21 23:51:05 -04:00
comfyanonymous
ebf6b52e32 ComfyUI v0.18.1 2026-03-21 22:32:16 -04:00
rattus
25b6d1d629 wan: vae: Fix light/color change (#13101)
There was an issue where the resample split was too early and dropped one
of the rolling convolutions a frame early. This is most noticable as a
lighting/color change between pixel frames 5->6 (latent 2->3), or as a
lighting change between the first and last frame in an FLF wan flow.
2026-03-21 18:44:35 -04:00
comfyanonymous
11c15d8832 Fix fp16 intermediates giving different results. (#13100) 2026-03-21 17:53:25 -04:00
comfyanonymous
b5d32e6ad2 Fix sampling issue with fp16 intermediates. (#13099) 2026-03-21 17:47:42 -04:00
comfyanonymous
a11f68dd3b Fix canny node not working with fp16. (#13085) 2026-03-20 23:15:50 -04:00
comfyanonymous
dc719cde9c ComfyUI version 0.18.0 2026-03-20 20:09:15 -04:00
Jedrzej Kosinski
87cda1fc25 Move inline comfy.context_windows imports to top-level in model_base.py (#13083)
The recent PR that added resize_cond_for_context_window methods to
model classes used inline 'import comfy.context_windows' in each
method body. This moves that import to the top-level import section,
replacing 4 duplicate inline imports with a single top-level one.
2026-03-20 20:03:42 -04:00
comfyanonymous
45d5c83a30 Make EmptyImage node follow intermediate device/dtype. (#13079) 2026-03-20 16:08:26 -04:00
Alexander Piskun
c646d211be feat(api-nodes): add Quiver SVG nodes (#13047) 2026-03-20 12:23:16 -07:00
drozbay
589228e671 Add slice_cond and per-model context window cond resizing (#12645)
* Add slice_cond and per-model context window cond resizing

* Fix cond_value.size() call in context window cond resizing

* Expose additional advanced inputs for ContextWindowsManualNode

Necessary for WanAnimate context windows workflow, which needs cond_retain_index_list = 0 to work properly with its reference input.

---------
2026-03-19 20:42:42 -07:00
Alexander Piskun
e4455fd43a [API Nodes] mark seedream-3-0-t2i and seedance-1-0-lite models as deprecated (#13060)
* chore(api-nodes): mark seedream-3-0-t2i and seedance-1-0-lite models as deprecated

* fix(api-nodes): fixed old regression in the ByteDanceImageReference node

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-03-19 20:05:01 -07:00
rattus
f49856af57 ltx: vae: Fix missing init variable (#13074)
Forgot to push this ammendment. Previous test results apply to this.
2026-03-19 22:34:58 -04:00
rattus
82b868a45a Fix VRAM leak in tiler fallback in video VAEs (#13073)
* sd: soft_empty_cache on tiler fallback

This doesnt cost a lot and creates the expected VRAM reduction in
resource monitors when you fallback to tiler.

* wan: vae: Don't recursion in local fns (move run_up)

Moved Decoder3d’s recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.

* ltx: vae: Don't recursion in local fns (move run_up)

Mov the recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.
2026-03-19 22:30:27 -04:00
comfyanonymous
8458ae2686 Revert "fix: run text encoders on MPS GPU instead of CPU for Apple Silicon (#…" (#13070)
This reverts commit b941913f1d.
2026-03-19 15:27:55 -04:00
Jukka Seppänen
fd0261d2bc Reduce tiled decode peak memory (#13050) 2026-03-19 13:29:34 -04:00
rattus
ab14541ef7 memory: Add more exclusion criteria to pinned read (#13067) 2026-03-19 10:03:20 -07:00
rattus
6589562ae3 ltx: vae: implement chunked encoder + CPU IO chunking (Big VRAM reductions) (#13062)
* ltx: vae: add cache state to downsample block

* ltx: vae: Add time stride awareness to causal_conv_3d

* ltx: vae: Automate truncation for encoder

Other VAEs just truncate without error. Do the same.

* sd/ltx: Make chunked_io a flag in its own right

Taking this bi-direcitonal, so make it a for-purpose named flag.

* ltx: vae: implement chunked encoder + CPU IO chunking

People are doing things with big frame counts in LTX including V2V
flows. Implement the time-chunked encoder to keep the VRAM down, with
the converse of the new CPU pre-allocation technique, where the chunks
are brought from the CPU JIT.

* ltx: vae-encode: round chunk sizes more strictly

Only powers of 2 and multiple of 8 are valid due to cache slicing.
2026-03-19 10:01:12 -07:00
rattus
fabed694a2 ltx: vae: implement chunked encoder + CPU IO chunking (Big VRAM reductions) (#13062)
* ltx: vae: add cache state to downsample block

* ltx: vae: Add time stride awareness to causal_conv_3d

* ltx: vae: Automate truncation for encoder

Other VAEs just truncate without error. Do the same.

* sd/ltx: Make chunked_io a flag in its own right

Taking this bi-direcitonal, so make it a for-purpose named flag.

* ltx: vae: implement chunked encoder + CPU IO chunking

People are doing things with big frame counts in LTX including V2V
flows. Implement the time-chunked encoder to keep the VRAM down, with
the converse of the new CPU pre-allocation technique, where the chunks
are brought from the CPU JIT.

* ltx: vae-encode: round chunk sizes more strictly

Only powers of 2 and multiple of 8 are valid due to cache slicing.
2026-03-19 09:58:47 -07:00
dante01yoon
2a6e3dc7a8 Add isfinite guard, exception chaining, and unit tests for Number Convert node
- Add math.isfinite() check to prevent int() crash on inf/nan string inputs
- Use 'from None' for cleaner exception chaining on string parse failure
- Add 21 unit tests covering all input types and error paths
2026-03-19 10:41:02 +09:00
comfyanonymous
f6b869d7d3 fp16 intermediates doen't work for some text enc models. (#13056) 2026-03-18 19:42:28 -04:00
comfyanonymous
56ff88f951 Fix regression. (#13053) 2026-03-18 18:35:25 -04:00
Jukka Seppänen
9fff091f35 Further Reduce LTX VAE decode peak RAM usage (#13052) 2026-03-18 18:32:26 -04:00
comfyanonymous
dcd659590f Make more intermediate values follow the intermediate dtype. (#13051) 2026-03-18 18:14:18 -04:00
Alexander Brown
b67ed2a45f Update comfyui-frontend-package version to 1.41.21 (#13035) 2026-03-18 16:36:39 -04:00
Alexander Piskun
06957022d4 fix(api-nodes): add support for "thought_image" in Nano Banana 2 and corrected price badges (#13038) 2026-03-18 10:21:58 -07:00
dante01yoon
dae107e430 Register nodes_number_convert.py in extras_files list
Without this entry in nodes.py, the Number Convert node file
would not be discovered and loaded at startup.
2026-03-18 19:56:43 +09:00
dante01yoon
82cf5d88c2 Add Number Convert node for unified numeric type conversion
Consolidates fragmented IntToFloat/FloatToInt nodes (previously only
available via third-party packs like ComfyMath, FillNodes, etc.) into
a single core node.

- Single input accepting INT, FLOAT, STRING, and BOOL types
- Two outputs: FLOAT and INT
- Conversion: bool→0/1, string→parsed number, float↔int standard cast
- Follows Math Expression node patterns (comfy_api, io.Schema, etc.)

Refs: COM-16925
2026-03-18 19:56:43 +09:00
Anton Bukov
b941913f1d fix: run text encoders on MPS GPU instead of CPU for Apple Silicon (#12809)
On Apple Silicon, `vram_state` is set to `VRAMState.SHARED` because
CPU and GPU share unified memory. However, `text_encoder_device()`
only checked for `HIGH_VRAM` and `NORMAL_VRAM`, causing all text
encoders to fall back to CPU on MPS devices.

Adding `VRAMState.SHARED` to the condition allows non-quantized text
encoders (e.g. bf16 Gemma 3 12B) to run on the MPS GPU, providing
significant speedup for text encoding and prompt generation.

Note: quantized models (fp4/fp8) that use float8_e4m3fn internally
will still fall back to CPU via the `supports_cast()` check in
`CLIP.__init__()`, since MPS does not support fp8 dtypes.
2026-03-17 21:21:32 -04:00
rattus
cad24ce262 cascade: remove dead weight init code (#13026)
This weight init process is fully shadowed be the weight load and
doesnt work in dynamic_vram were the weight allocation is deferred.
2026-03-17 20:59:10 -04:00
comfyanonymous
68d542cc06 Fix case where pixel space VAE could cause issues. (#13030) 2026-03-17 20:46:22 -04:00
Jukka Seppänen
735a0465e5 Inplace VAE output processing to reduce peak RAM consumption. (#13028) 2026-03-17 20:20:49 -04:00
Dr.Lt.Data
8b9d039f26 bump manager version to 4.1b6 (#13022) 2026-03-17 18:17:03 -04:00
rattus
035414ede4 Reduce WAN VAE VRAM, Save use cases for OOM/Tiler (#13014)
* wan: vae: encoder: Add feature cache layer that corks singles

If a downsample only gives you a single frame, save it to the feature
cache and return nothing to the top level. This increases the
efficiency of cacheability, but also prepares support for going two
by two rather than four by four on the frames.

* wan: remove all concatentation with the feature cache

The loopers are now responsible for ensuring that non-final frames are
processes at least two-by-two, elimiating the need for this cat case.

* wan: vae: recurse and chunk for 2+2 frames on decode

Avoid having to clone off slices of 4 frame chunks and reduce the size
of the big 6 frame convolutions down to 4. Save the VRAMs.

* wan: encode frames 2x2.

Reduce VRAM usage greatly by encoding frames 2 at a time rather than
4.

* wan: vae: remove cloning

The loopers now control the chunking such there is noever more than 2
frames, so just cache these slices directly and avoid the clone
allocations completely.

* wan: vae: free consumer caller tensors on recursion

* wan: vae: restyle a little to match LTX
2026-03-17 17:34:39 -04:00
rattus
1a157e1f97 Reduce LTX VAE VRAM usage and save use cases from OOMs/Tiler (#13013)
* ltx: vae: scale the chunk size with the users VRAM

Scale this linearly down for users with low VRAM.

* ltx: vae: free non-chunking recursive intermediates

* ltx: vae: cleanup some intermediates

The conv layer can be the VRAM peak and it does a torch.cat. So cleanup
the pieces of the cat. Also clear our the cache ASAP as each layer detect
its end as this VAE surges in VRAM at the end due to the ended padding
increasing the size of the final frame convolutions off-the-books to
the chunker. So if all the earlier layers free up their cache it can
offset that surge.

Its a fragmentation nightmare, and the chance of it having to recache the
pyt allocator is very high, but you wont OOM.
2026-03-17 17:32:43 -04:00
Christian Byrne
ed7c2c6579 Mark weight_dtype as advanced input in Load Diffusion Model node (#12769)
Mark the weight_dtype parameter in UNETLoader (Load Diffusion Model) as
an advanced input to reduce UI complexity for new users. The parameter
is now hidden behind an expandable Advanced section, matching the
pattern used for other advanced inputs like device, tile_size, and
overlap.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbaf1-d3c0-718e-a325-318baba86dec
2026-03-17 07:24:00 -07:00
33 changed files with 1139 additions and 226 deletions

View File

@@ -93,6 +93,50 @@ class IndexListCallbacks:
return {}
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
return None
cond_tensor = cond_value.cond
if temporal_dim >= cond_tensor.ndim:
return None
cond_size = cond_tensor.size(temporal_dim)
if temporal_scale == 1:
expected_size = x_in.size(window.dim) - temporal_offset
if cond_size != expected_size:
return None
if temporal_offset == 0 and temporal_scale == 1:
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
return cond_value._copy_with(sliced)
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0:
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
indices = [i for i in indices if 0 <= i]
else:
indices = list(window.index_list)
if not indices:
return None
if temporal_scale > 1:
scaled = []
for i in indices:
for k in range(temporal_scale):
si = i * temporal_scale + k
if si < cond_size:
scaled.append(si)
indices = scaled
if not indices:
return None
idx = tuple([slice(None)] * temporal_dim + [indices])
sliced = cond_tensor[idx].to(device)
return cond_value._copy_with(sliced)
@dataclass
class ContextSchedule:
name: str
@@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item[cond_key] = result
handled = True
break
if not handled and self._model is not None:
result = self._model.resize_cond_for_context_window(
cond_key, cond_value, window, x_in, device,
retain_index_list=self.cond_retain_index_list)
if result is not None:
new_cond_item[cond_key] = result
handled = True
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
@@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))

View File

@@ -136,16 +136,7 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

View File

@@ -386,7 +386,7 @@ class Flux(nn.Module):
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])

View File

@@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
# Inject reference audio for ID-LoRA in-context conditioning
ref_audio = kwargs.get("ref_audio", None)
ref_audio_seq_len = 0
if ref_audio is not None:
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
if ref_tokens.shape[0] < ax.shape[0]:
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
ref_audio_seq_len = ref_tokens.shape[1]
B = ax.shape[0]
# Compute negative temporal positions matching ID-LoRA convention:
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
p = self.a_patchifier
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
time_offset = ref_end[-1].item() + tpl
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
additional_args["target_audio_seq_len"] = ax.shape[1]
ax = torch.cat([ref_tokens, ax], dim=1)
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
@@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0 and a_timestep is not None:
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
target_len = kwargs.get("target_audio_seq_len")
if a_timestep.dim() <= 1:
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
@@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
# Trim reference audio tokens before unpatchification
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0:
ax = ax[:, ref_audio_seq_len:]
if a_embedded_timestep.shape[1] > 1:
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()

View File

@@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(stride, int):
self.time_stride = stride
else:
self.time_stride = stride[0]
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
@@ -58,16 +63,25 @@ class CausalConv3d(nn.Module):
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
input_length = sum([piece.shape[2] for piece in pieces])
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
if needs_caching and cache_length == 0:
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
if needs_caching and x.shape[2] >= cache_length:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

View File

@@ -233,10 +233,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
sample = self.conv_in(sample)
checkpoint_fn = (
@@ -247,10 +244,14 @@ class Encoder(nn.Module):
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
if sample is None or sample.shape[2] == 0:
return None
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if sample is None or sample.shape[2] == 0:
return None
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
@@ -282,9 +283,35 @@ class Encoder(nn.Module):
return sample
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
outputs = []
samples = [sample[:, :, :1, :, :]]
if sample.shape[2] > 1:
chunk_t = max(2, max_chunk_size // frame_size)
if chunk_t < 4:
chunk_t = 2
elif chunk_t < 8:
chunk_t = 4
else:
chunk_t = (chunk_t // 8) * 8
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
for chunk_idx, chunk in enumerate(samples):
if chunk_idx == len(samples) - 1:
mark_conv3d_ended(self)
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
output = self._forward_chunk(chunk)
if output is not None:
outputs.append(output)
return torch_cat_if_needed(outputs, dim=2)
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
@@ -297,7 +324,23 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2
def get_max_chunk_size(device: torch.device) -> int:
total_memory = comfy.model_management.get_total_memory(dev=device)
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
return MIN_CHUNK_SIZE
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
return MAX_CHUNK_SIZE
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
)
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
class Decoder(nn.Module):
r"""
@@ -457,6 +500,17 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
ts, hs, ws, to = 1, 1, 1, 0
for block in self.up_blocks:
if isinstance(block, DepthToSpaceUpsample):
ts *= block.stride[0]
hs *= block.stride[1]
ws *= block.stride[2]
if block.stride[0] > 1:
to = to * block.stride[0] + 1
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
@@ -478,11 +532,62 @@ class Decoder(nn.Module):
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
return
up_block = self.up_blocks[idx]
if ended:
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
@@ -497,6 +602,7 @@ class Decoder(nn.Module):
)
timestep_shift_scale = None
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
@@ -524,48 +630,18 @@ class Decoder(nn.Module):
)
timestep_shift_scale = ada_values.unbind(dim=1)
output = []
if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample.to(comfy.model_management.intermediate_device()))
return
max_chunk_size = get_max_chunk_size(sample.device)
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
return output_buffer
def forward(self, *args, **kwargs):
try:
@@ -689,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module):
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
tid = threading.get_ident()
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
if cached_input is not None:
x = torch_cat_if_needed([cached_input, x], dim=2)
cached_input = None
if self.stride[0] == 2 and pad_first:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
pad_first = False
if x.shape[2] < self.stride[0]:
cached_input = x
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return None
# skip connection
x_in = rearrange(
@@ -709,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module):
# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and x.shape[2] == 1:
if cached_x is not None:
x = torch_cat_if_needed([cached_x, x], dim=2)
cached_x = None
else:
cached_x = x
x = None
x = x + x_in
if x is not None:
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
cached = add_exchange_cache(x, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return x
@@ -1050,6 +1150,8 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
comfy_has_chunked_io = True
def __init__(self, version=0, config=None):
super().__init__()
@@ -1192,14 +1294,15 @@ class VideoVAE(nn.Module):
}
return config
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
def encode(self, x, device=None):
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x):
def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)
def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)

View File

@@ -99,7 +99,7 @@ class Resample(nn.Module):
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
@@ -109,22 +109,7 @@ class Resample(nn.Module):
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
@@ -145,19 +130,24 @@ class Resample(nn.Module):
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
feat_cache[idx] = x
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
cache_x = x[:, :, -1:, :, :]
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
deferred_x = feat_cache[idx + 1]
if deferred_x is not None:
x = torch.cat([deferred_x, x], 2)
feat_cache[idx + 1] = None
if x.shape[2] == 1 and not final:
feat_cache[idx + 1] = x
x = None
feat_idx[0] += 2
return x
@@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
def forward(self, x):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
@@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
x = layer(x, feat_cache, feat_idx, final=final)
if x is None:
return None
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, final=final)
else:
x = layer(x)
@@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -389,18 +360,48 @@ class Decoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
for frame_idx in range(0, x.shape[2], 2):
self.run_up(
layer_idx + 1,
[x[:, :, frame_idx:frame_idx + 2, :, :]],
feat_cache,
feat_idx.copy(),
out_chunks,
)
del x
return
next_x_ref = [x]
del x
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -409,42 +410,21 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
out_chunks = []
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
return out_chunks
def count_conv3d(model):
def count_cache_layers(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
count += 1
return count
@@ -482,11 +462,12 @@ class WanVAE(nn.Module):
conv_idx = [0]
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
t = 1 + ((t - 1) // 4) * 4
iter_ = 1 + (t - 1) // 2
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.encoder)
## 对encode输入的x按时间拆分为1、4、4、4....
feat_map = [None] * count_cache_layers(self.encoder)
## 对encode输入的x按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
for i in range(iter_):
conv_idx = [0]
if i == 0:
@@ -496,20 +477,23 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
feat_idx=conv_idx,
final=(i == (iter_ - 1)))
if out_ is None:
continue
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
return mu
def decode(self, z):
conv_idx = [0]
# z: [b,c,t,h,w]
iter_ = z.shape[2]
iter_ = 1 + z.shape[2] // 2
feat_map = None
if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder)
feat_map = [None] * count_cache_layers(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
@@ -520,8 +504,8 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
return out
out += out_
return torch.cat(out, 2)

View File

@@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination):
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
return False
if info.size == 0:

View File

@@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@@ -285,6 +286,12 @@ class BaseModel(torch.nn.Module):
return data
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
"""Override in subclasses to handle model-specific cond slicing for context windows.
Return a sliced cond object, or None to fall through to default handling.
Use comfy.context_windows.slice_cond() for common cases."""
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
@@ -930,9 +937,10 @@ class LongCatImage(Flux):
transformer_options = transformer_options.copy()
rope_opts = transformer_options.get("rope_options", {})
rope_opts = dict(rope_opts)
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
rope_opts.setdefault("shift_t", 1.0)
rope_opts.setdefault("shift_y", 512.0)
rope_opts.setdefault("shift_x", 512.0)
rope_opts.setdefault("shift_y", pe_len)
rope_opts.setdefault("shift_x", pe_len)
transformer_options["rope_options"] = rope_opts
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
@@ -1053,6 +1061,10 @@ class LTXAV(BaseModel):
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
ref_audio = kwargs.get("ref_audio", None)
if ref_audio is not None:
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1375,6 +1387,11 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "vace_context":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@@ -1427,6 +1444,11 @@ class WAN21_HuMo(WAN21):
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@@ -1444,6 +1466,13 @@ class WAN22_Animate(WAN21):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "face_pixel_values":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
if cond_key == "pose_latents":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@@ -1480,6 +1509,11 @@ class WAN22_S2V(WAN21):
out['reference_motion'] = reference_motion.shape
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@@ -8,12 +8,12 @@ import comfy.nested_tensor
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples

View File

@@ -985,8 +985,8 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
noise = noise.to(device)
latent_image = latent_image.to(device)
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
@@ -1028,6 +1028,7 @@ class CFGGuider:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
denoise_mask = denoise_mask.float()
self.conds = {}
for k in self.original_conds:

View File

@@ -455,7 +455,7 @@ class VAE:
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
@@ -951,12 +951,23 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
@@ -967,6 +978,7 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -1027,8 +1039,13 @@ class VAE:
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
@@ -1043,6 +1060,7 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4

View File

@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2]
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
else:
first_pooled = pooled
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
extra[k] = v
r = r + (extra,)

View File

@@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
if attention_mask is not None:
# Assign compact sequential positions to attended tokens only,
# skipping over padding so post-padding tokens aren't inflated.
after_mask = attention_mask[0, end:]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
else:
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]

View File

@@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
return [output]
IMAGE_PAD_TOKEN_ID = 151655
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
@@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
name="qwen25_7b",
tokenizer=LongCatImageBaseTokenizer,
)
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
skip_template = False
if text.startswith("<|im_start|>"):
skip_template = True
@@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
else:
has_images = images is not None and len(images) > 0
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
prefix_ids = base_tok.tokenizer(
self.longcat_template_prefix, add_special_tokens=False
template_prefix, add_special_tokens=False
)["input_ids"]
suffix_ids = base_tok.tokenizer(
self.longcat_template_suffix, add_special_tokens=False
self.SUFFIX, add_special_tokens=False
)["input_ids"]
prompt_tokens = base_tok.tokenize_with_weights(
@@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
suffix_pairs = [(t, 1.0) for t in suffix_ids]
combined = prefix_pairs + prompt_pairs + suffix_pairs
if has_images:
embed_count = 0
for i in range(len(combined)):
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
embed_count += 1
tokens = {"qwen25_7b": [combined]}
return tokens

View File

@@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
hidden_states = self.merger(hidden_states)
# Potentially important for spatially precise edits. This is present in the HF implementation.
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states

View File

@@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out = output[b:b+1].zero_()
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
@@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
@@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if pbar is not None:
pbar.update(1)
output[b:b+1] = out/out_div
out.div_(out_div)
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):

View File

@@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
thought: bool | None = Field(None)
class GeminiTextPart(BaseModel):

View File

@@ -0,0 +1,43 @@
from pydantic import BaseModel, Field
class QuiverImageObject(BaseModel):
url: str = Field(...)
class QuiverTextToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
prompt: str = Field(...)
instructions: str | None = Field(default=None)
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverImageToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
image: QuiverImageObject = Field(...)
auto_crop: bool | None = Field(default=None)
target_size: int | None = Field(default=None, ge=128, le=4096)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverSVGResponseItem(BaseModel):
svg: str = Field(...)
mime_type: str | None = Field(default="image/svg+xml")
class QuiverSVGUsage(BaseModel):
total_tokens: int | None = Field(default=None)
input_tokens: int | None = Field(default=None)
output_tokens: int | None = Field(default=None)
class QuiverSVGResponse(BaseModel):
id: str | None = Field(default=None)
created: int | None = Field(default=None)
data: list[QuiverSVGResponseItem] = Field(...)
usage: QuiverSVGUsage | None = Field(default=None)

View File

@@ -47,6 +47,10 @@ SEEDREAM_MODELS = {
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
logger = logging.getLogger(__name__)
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
@@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode):
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
is_deprecated=True,
)
@classmethod
@@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
]
return await process_video_task(
cls,
payload=Image2VideoTaskCreationRequest(model=model, content=x),
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -952,6 +957,12 @@ async def process_video_task(
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
if payload.model in DEPRECATED_MODELS:
logger.warning(
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
payload.model,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),

View File

@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
$m := widgets.model;
$r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2");
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
@@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/*")
for part in parts:
if (part.thought is True) != thought:
continue
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
@@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
outputs=[
IO.Image.Output(),
IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
class GeminiExtension(ComfyExtension):

View File

@@ -0,0 +1,291 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.quiver import (
QuiverImageObject,
QuiverImageToSVGRequest,
QuiverSVGResponse,
QuiverTextToSVGRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
sync_op,
upload_image_to_comfyapi,
validate_string,
)
from comfy_extras.nodes_images import SVG
class QuiverTextToSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="QuiverTextToSVGNode",
display_name="Quiver Text to SVG",
category="api node/image/Quiver",
description="Generate an SVG from a text prompt using Quiver AI.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired SVG output.",
),
IO.String.Input(
"instructions",
multiline=True,
default="",
tooltip="Additional style or formatting guidance.",
optional=True,
),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="ref_",
min=0,
max=4,
),
tooltip="Up to 4 reference images to guide the generation.",
optional=True,
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
tooltip="Model to use for SVG generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
instructions: str = None,
reference_images: IO.Autogrow.Type = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
references = None
if reference_images:
references = []
for key in reference_images:
url = await upload_image_to_comfyapi(cls, reference_images[key])
references.append(QuiverImageObject(url=url))
if len(references) > 4:
raise ValueError("Maximum 4 reference images are allowed.")
instructions_val = instructions.strip() if instructions else None
if instructions_val == "":
instructions_val = None
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
response_model=QuiverSVGResponse,
data=QuiverTextToSVGRequest(
model=model["model"],
prompt=prompt,
instructions=instructions_val,
references=references,
temperature=model.get("temperature"),
top_p=model.get("top_p"),
presence_penalty=model.get("presence_penalty"),
),
)
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
return IO.NodeOutput(SVG(svg_data))
class QuiverImageToSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="QuiverImageToSVGNode",
display_name="Quiver Image to SVG",
category="api node/image/Quiver",
description="Vectorize a raster image into SVG using Quiver AI.",
inputs=[
IO.Image.Input(
"image",
tooltip="Input image to vectorize.",
),
IO.Boolean.Input(
"auto_crop",
default=False,
tooltip="Automatically crop to the dominant subject.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"arrow-preview",
[
IO.Int.Input(
"target_size",
default=1024,
min=128,
max=4096,
tooltip="Square resize target in pixels.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Randomness control. Higher values increase randomness.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=1.0,
min=0.05,
max=1.0,
step=0.05,
display_mode=IO.NumberDisplay.slider,
tooltip="Nucleus sampling parameter.",
advanced=True,
),
IO.Float.Input(
"presence_penalty",
default=0.0,
min=-2.0,
max=2.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Token presence penalty.",
advanced=True,
),
],
),
],
tooltip="Model to use for SVG vectorization.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.429}""",
),
)
@classmethod
async def execute(
cls,
image,
auto_crop: bool,
model: dict,
seed: int,
) -> IO.NodeOutput:
image_url = await upload_image_to_comfyapi(cls, image)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
response_model=QuiverSVGResponse,
data=QuiverImageToSVGRequest(
model=model["model"],
image=QuiverImageObject(url=image_url),
auto_crop=auto_crop if auto_crop else None,
target_size=model.get("target_size"),
temperature=model.get("temperature"),
top_p=model.get("top_p"),
presence_penalty=model.get("presence_penalty"),
),
)
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
return IO.NodeOutput(SVG(svg_data))
class QuiverExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
QuiverTextToSVGNode,
QuiverImageToSVGNode,
]
async def comfy_entrypoint() -> QuiverExtension:
return QuiverExtension()

View File

@@ -3,6 +3,7 @@ from typing_extensions import override
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
import torch
class Canny(io.ComfyNode):
@@ -29,8 +30,8 @@ class Canny(io.ComfyNode):
@classmethod
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
return io.NodeOutput(img_out)

View File

@@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),

View File

@@ -3,6 +3,7 @@ import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.samplers
import comfy.utils
import math
import numpy as np
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
return io.NodeOutput(video_latent, audio_latent)
class LTXVReferenceAudio(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVReferenceAudio",
display_name="LTXV Reference Audio (ID-LoRA)",
category="conditioning/audio",
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
],
outputs=[
io.Model.Output(),
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
# Encode reference audio to latents and patchify
audio_latents = audio_vae.encode(reference_audio)
b, c, t, f = audio_latents.shape
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
ref_audio = {"tokens": ref_tokens}
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
# Patch model with identity guidance
m = model.clone()
scale = identity_guidance_scale
model_sampling = m.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
def post_cfg_function(args):
if scale == 0:
return args["denoised"]
sigma = args["sigma"]
sigma_ = sigma[0].item()
if sigma_ > sigma_start or sigma_ < sigma_end:
return args["denoised"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
model_options = args["model_options"].copy()
x = args["input"]
# Strip ref_audio from conditioning for the no-reference pass
noref_cond = []
for entry in cond:
new_entry = entry.copy()
mc = new_entry.get("model_conds", {}).copy()
mc.pop("ref_audio", None)
new_entry["model_conds"] = mc
noref_cond.append(new_entry)
(pred_noref,) = comfy.samplers.calc_cond_batch(
args["model"], [noref_cond], x, sigma, model_options
)
return cfg_result + (cond_pred - pred_noref) * scale
m.set_model_sampler_post_cfg_function(post_cfg_function)
return io.NodeOutput(m, positive, negative)
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
LTXVCropGuides,
LTXVConcatAVLatent,
LTXVSeparateAVLatent,
LTXVReferenceAudio,
]

View File

@@ -0,0 +1,79 @@
"""Number Convert node for unified numeric type conversion.
Provides a single node that converts INT, FLOAT, STRING, and BOOL
inputs into FLOAT and INT outputs.
"""
from __future__ import annotations
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class NumberConvertNode(io.ComfyNode):
"""Converts various types to numeric FLOAT and INT outputs."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ComfyNumberConvert",
display_name="Number Convert",
category="math",
search_aliases=[
"int to float", "float to int", "number convert",
"int2float", "float2int", "cast", "parse number",
"string to number", "bool to int",
],
inputs=[
io.MultiType.Input(
"value",
[io.Int, io.Float, io.String, io.Boolean],
display_name="value",
),
],
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
],
)
@classmethod
def execute(cls, value) -> io.NodeOutput:
if isinstance(value, bool):
float_val = 1.0 if value else 0.0
elif isinstance(value, (int, float)):
float_val = float(value)
elif isinstance(value, str):
text = value.strip()
if not text:
raise ValueError("Cannot convert empty string to number.")
try:
float_val = float(text)
except ValueError:
raise ValueError(
f"Cannot convert string to number: {value!r}"
) from None
else:
raise TypeError(
f"Unsupported input type: {type(value).__name__}"
)
if not math.isfinite(float_val):
raise ValueError(
f"Cannot convert non-finite value to number: {float_val}"
)
return io.NodeOutput(float_val, int(float_val))
class NumberConvertExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [NumberConvertNode]
async def comfy_entrypoint() -> NumberConvertExtension:
return NumberConvertExtension()

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.17.0"
__version__ = "0.18.1"

View File

@@ -471,6 +471,9 @@ if __name__ == "__main__":
if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
if args.disable_dynamic_vram:
logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")
event_loop, _, start_all_func = start_comfyui()
try:
x = start_all_func()

View File

@@ -1 +1 @@
comfyui_manager==4.1b5
comfyui_manager==4.1b8

View File

@@ -1966,9 +1966,11 @@ class EmptyImage:
CATEGORY = "image"
def generate(self, width, height, batch_size=1, color=0):
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
dtype = comfy.model_management.intermediate_dtype()
device = comfy.model_management.intermediate_device()
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype)
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype)
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype)
return (torch.cat((r, g, b), dim=-1), )
class ImagePadForOutpaint:
@@ -2452,6 +2454,7 @@ async def init_builtin_extra_nodes():
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
"nodes_number_convert.py",
"nodes_painter.py",
]

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.17.0"
version = "0.18.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.41.20
comfyui-frontend-package==1.41.21
comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3
torch

View File

@@ -0,0 +1,123 @@
import pytest
from unittest.mock import patch, MagicMock
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
from comfy_extras.nodes_number_convert import NumberConvertNode
class TestNumberConvertExecute:
@staticmethod
def _exec(value) -> object:
return NumberConvertNode.execute(value)
# --- INT input ---
def test_int_input(self):
result = self._exec(42)
assert result[0] == 42.0
assert result[1] == 42
def test_int_zero(self):
result = self._exec(0)
assert result[0] == 0.0
assert result[1] == 0
def test_int_negative(self):
result = self._exec(-7)
assert result[0] == -7.0
assert result[1] == -7
# --- FLOAT input ---
def test_float_input(self):
result = self._exec(3.14)
assert result[0] == 3.14
assert result[1] == 3
def test_float_truncation_toward_zero(self):
result = self._exec(-2.9)
assert result[0] == -2.9
assert result[1] == -2 # int() truncates toward zero, not floor
def test_float_output_type(self):
result = self._exec(5)
assert isinstance(result[0], float)
def test_int_output_type(self):
result = self._exec(5.7)
assert isinstance(result[1], int)
# --- BOOL input ---
def test_bool_true(self):
result = self._exec(True)
assert result[0] == 1.0
assert result[1] == 1
def test_bool_false(self):
result = self._exec(False)
assert result[0] == 0.0
assert result[1] == 0
# --- STRING input ---
def test_string_integer(self):
result = self._exec("42")
assert result[0] == 42.0
assert result[1] == 42
def test_string_float(self):
result = self._exec("3.14")
assert result[0] == 3.14
assert result[1] == 3
def test_string_negative(self):
result = self._exec("-5.5")
assert result[0] == -5.5
assert result[1] == -5
def test_string_with_whitespace(self):
result = self._exec(" 7.0 ")
assert result[0] == 7.0
assert result[1] == 7
def test_string_scientific_notation(self):
result = self._exec("1e3")
assert result[0] == 1000.0
assert result[1] == 1000
# --- STRING error paths ---
def test_empty_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert empty string"):
self._exec("")
def test_whitespace_only_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert empty string"):
self._exec(" ")
def test_non_numeric_string_raises(self):
with pytest.raises(ValueError, match="Cannot convert string to number"):
self._exec("abc")
def test_string_inf_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("inf")
def test_string_nan_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("nan")
def test_string_negative_inf_raises(self):
with pytest.raises(ValueError, match="non-finite"):
self._exec("-inf")
# --- Unsupported type ---
def test_unsupported_type_raises(self):
with pytest.raises(TypeError, match="Unsupported input type"):
self._exec([1, 2, 3])