Compare commits


37 Commits

Author SHA1 Message Date
bymyself
e4f3d335dc feat: Add VideoSlice node with lazy operations on VideoInput
- Add VideoOp base class and SliceOp in _input/video_types.py
- Add sliced() method to VideoInput that returns a copy with operation appended
- Each subclass applies operations in get_components() and get_frame_count()
- After materialization, VideoFromFile delegates to internal VideoFromComponents
- Add VideoSlice node that uses video.sliced(start_frame, frame_count)
- Add tests for SliceOp, sliced() behavior, and materialization
2026-01-23 20:52:15 -08:00
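A minimal sketch of the lazy-operation pattern this commit describes, using toy stand-ins (the class and method names mirror the commit message, but this is not the actual ComfyUI implementation and it operates on plain lists instead of decoded frames):

```python
import copy
from dataclasses import dataclass

@dataclass
class SliceOp:
    """Hypothetical slice operation: keep frame_count frames from start_frame."""
    start_frame: int
    frame_count: int

    def apply(self, frames):
        return frames[self.start_frame:self.start_frame + self.frame_count]

class VideoInput:
    """Toy VideoInput: operations are recorded lazily and only applied on access."""
    def __init__(self, frames, ops=None):
        self._frames = frames
        self._ops = list(ops or [])

    def sliced(self, start_frame, frame_count):
        # Return a copy with the operation appended; nothing is materialized yet.
        clone = copy.copy(self)
        clone._ops = self._ops + [SliceOp(start_frame, frame_count)]
        return clone

    def get_components(self):
        # Pending operations are applied only when components are requested.
        frames = self._frames
        for op in self._ops:
            frames = op.apply(frames)
        return frames

    def get_frame_count(self):
        return len(self.get_components())

video = VideoInput(list(range(100)))
clip = video.sliced(start_frame=10, frame_count=5)
assert clip.get_frame_count() == 5 and video.get_frame_count() == 100
```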
Omri Marom
d7f3241bf6 qwen_image: propagate attention mask. (#11966) 2026-01-22 20:02:31 -05:00
comfyanonymous
09a2e67151 Support loading flux 2 klein checkpoints saved with SaveCheckpoint. (#12033) 2026-01-22 18:20:48 -05:00
rattus
0fd1b78736 Reduce LTX2 VAE VRAM consumption (#12028)
* causal_video_ae: Remove attention ResNet

The attention_head_dim argument does not exist on this constructor, so this
is dead code. Remove it, since a generic attention mid-block in the VAE
conflicts with temporal rolling.

* ltx-vae: consolidate causal/non-causal code paths

* ltx-vae: add cache rolling adder

* ltx-vae: use cached adder for resnet

* ltx-vae: Implement rolling VAE

Implement a temporal rolling VAE for the LTX2 VAE.

Usually when doing a temporal rolling VAE you can simply chunk on time, relying
on causality and caching behind you as you go. The LTX VAE, however, is
non-causal.

So go whole hog and implement per-layer run-ahead and back-pressure between
the decoder layers, using recursive state between the layers.

Operations are amended with a temporal_cache_state{} which they can use to
hold any state they need for partial execution. Convolutions cache up to N-1
trailing input frames behind them, and skip connections need to cache the
mismatch between convolution input and output that arises from missing
future (non-causal) input.

Each call to run_up() processes a layer across a range of input that
may or may not be complete. It goes depth first, processing as much as
possible to digest frames through to the final output ASAP. If a layer runs
out of input due to convolution losses, it simply returns without action,
effectively applying back-pressure to the earlier layers. As the earlier
layers do more work and call deeper, the partial states are reconciled
and output continues to be digested depth first as much as possible.

Chunking is done using a size quota rather than a fixed frame length; any
layer can initiate chunking, and multiple layers can chunk at different
granularities. This removes the old limitation of always having to process
one latent frame in its entirety and holding 8 full decoded frames as
the VRAM peak.
2026-01-22 16:54:18 -05:00
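A toy illustration of the run-ahead / back-pressure scheme described above, using plain Python lists of "frames" instead of tensors (hypothetical helper names, fixed-width "convolutions" that just average a window, and no end-of-stream flush, so this is only a sketch of the control flow, not the LTX2 VAE code):

```python
class CachedTemporalOp:
    """Valid temporal 'convolution' of width k: emits one frame per full window
    and keeps the last k-1 inputs in temporal_cache_state for the next call."""
    def __init__(self, k):
        self.k = k
        self.temporal_cache_state = []  # trailing inputs not yet fully consumed

    def __call__(self, frames):
        buf = self.temporal_cache_state + frames
        out = [sum(buf[i:i + self.k]) / self.k for i in range(len(buf) - self.k + 1)]
        self.temporal_cache_state = buf[-(self.k - 1):] if self.k > 1 else []
        return out  # may be empty: not enough input yet -> back-pressure

def run_up(layers, idx, frames, output, quota=4):
    """Depth first: push frames as deep as possible, chunking by a size quota."""
    if idx == len(layers):
        output.extend(frames)
        return
    produced = layers[idx](frames)
    if not produced:        # layer ran out of input; back-pressure to the caller
        return
    for i in range(0, len(produced), quota):
        run_up(layers, idx + 1, produced[i:i + quota], output)

layers = [CachedTemporalOp(3), CachedTemporalOp(3)]
output = []
for chunk in ([1, 2], [3, 4], [5, 6, 7, 8]):  # feed the decoder in small pieces
    run_up(layers, 0, list(chunk), output)
# 8 inputs through two width-3 valid convolutions emit 8 - 2*(3-1) = 4 frames
assert len(output) == 4
```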
Terry Jia
8490eedadf add ply & 3dgs format in 3d node (#11474) 2026-01-22 09:46:56 -08:00
Alexander Piskun
72f6be1690 chore(api-nodes): rename BriaImage and OpenAIGImage nodes (#12022) 2026-01-21 23:42:04 -08:00
Jukka Seppänen
16b9aabd52 Support Multi/InfiniteTalk (#10179)
* re-init

* Update model_multitalk.py

* whitespace...

* Update model_multitalk.py

* remove print

* this is redundant

* remove import

* Restore preview functionality

* Move block_idx to transformer_options

* Remove LoopingSamplerCustomAdvanced

* Remove looping functionality, keep extension functionality

* Update model_multitalk.py

* Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn

* Chunk attention map calculation for multiple speakers to reduce peak VRAM usage

* Update model_multitalk.py

* Add ModelPatch type back

* Fix for latest upstream

* Use DynamicCombo for cleaner node

Basically just so that single_speaker mode hides the mask inputs and the second audio input

* Update nodes_wan.py
2026-01-21 23:09:48 -05:00
Jukka Seppänen
245f6139b6 More targeted embedding_connector loading for LTX2 text encoder (#11992)
Reduces errors
2026-01-21 23:05:06 -05:00
Jukka Seppänen
3365ad18a5 Support LTX2 tiny vae (taeltx_2) (#11929) 2026-01-21 23:03:51 -05:00
Jedrzej Kosinski
f09904720d Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020) 2026-01-21 23:01:35 -05:00
comfyanonymous
abe2ec26a6 Support the Anima model. (#12012) 2026-01-21 19:44:28 -05:00
Christian Byrne
bdeac8897e feat: Add search_aliases field to node schema (#12010)
* feat: Add search_aliases field to node schema

Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate).

Changes:
- Add `search_aliases: list[str]` to V3 Schema
- Add `SEARCH_ALIASES` support for V1 nodes
- Include field in `/object_info` response
- Add aliases to high-priority core nodes

V1 usage:
```python
class MyNode:
    SEARCH_ALIASES = ["alt name", "synonym"]
```

V3 usage:
```python
io.Schema(
    node_id="MyNode",
    search_aliases=["alt name", "synonym"],
    ...
)
```

## Related PRs
- Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this)
- Docs: Comfy-Org/docs#XXXX (draft - merge after stable)

* Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1
2026-01-21 15:36:02 -08:00
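A hypothetical sketch of how a client might use the new field once it is exposed in `/object_info` (the endpoint and the `search_aliases: list[str]` shape come from the commit above; everything else here is assumed):

```python
import json
from urllib.request import urlopen

# Assumes a local ComfyUI instance; each node entry is expected to carry a
# "search_aliases" list as described in the commit above.
with urlopen("http://127.0.0.1:8188/object_info") as resp:
    object_info = json.load(resp)

def search_nodes(query, object_info):
    """Match a query against node names and their search aliases."""
    q = query.lower()
    return [
        name for name, info in object_info.items()
        if q in name.lower()
        or any(q in alias.lower() for alias in info.get("search_aliases", []))
    ]

print(search_nodes("text concat", object_info))  # e.g. ["StringConcatenate"]
```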
Alexander Piskun
451af70154 fix(api-nodes-Vidu): allow passing up to 7 subjects in Vidu Reference node (#12002) 2026-01-21 04:03:45 -08:00
Markury
0fc15700be Add LyCoris LoKr MLP layer support for Flux2 (#11997) 2026-01-20 23:18:33 -05:00
comfyanonymous
e755268e7b Config for Qwen 3 0.6B model. (#11998) 2026-01-20 23:08:31 -05:00
Mylo
c4a14df9a3 Dynamically detect chroma radiance patch size (#11991) 2026-01-20 18:46:11 -05:00
Ivan Zorin
965d0ed509 fix: remove normalization of audio in LTX Mel spectrogram creation (#11990)
For the LTX Audio VAE, remove normalization of audio during Mel spectrogram creation.
This aligns inference with training and prevents loud audio from being attenuated.
2026-01-20 18:44:28 -05:00
Alexander Piskun
ddc541ffda feat(api-nodes): add WaveSpeed nodes (#11945) 2026-01-20 13:05:40 -08:00
comfyanonymous
8ccc0c94fa Make omni stuff work on regular z image for easier testing. (#11985) 2026-01-20 00:32:00 -05:00
Comfy Org PR Bot
4edb87aa50 Bump comfyui-frontend-package to 1.37.11 (#11976) 2026-01-19 23:57:50 -05:00
ComfyUI Wiki
0fc3b6e3a6 chore: update workflow templates to v0.8.15 (#11984) 2026-01-19 23:17:56 -05:00
comfyanonymous
2108167f9f Support zimage omni base model. (#11979) 2026-01-19 23:17:38 -05:00
comfyanonymous
9d273d3ab1 ComfyUI v0.10.0 2026-01-19 22:40:18 -05:00
comfyanonymous
70c91b8248 Fix #11963 (#11982) 2026-01-19 22:32:40 -05:00
rkfg
0da5a0fe58 Convert mono audio to fake stereo for LTXV VAE encoding (#11965) 2026-01-19 22:12:02 -05:00
comfyanonymous
e0eacb0688 Simpler way to implement the #11980 loras. (#11981) 2026-01-19 22:00:36 -05:00
Jedrzej Kosinski
7458e20465 Make Autogrow validation work properly (#11977)
* In-progress autogrow validation fixes: properly look at required/optional inputs; now working on the edge case where all inputs are optional and nothing is plugged in (an empty dictionary should just be passed into the node)

* Allow autogrow to work with all inputs being optional

* Revert accidentally pushed changes to nodes_logic.py
2026-01-19 16:58:30 -08:00
Jedrzej Kosinski
b931b37e30 feat(api-nodes): add Bria Edit node (#11978)
Co-authored-by: Alexander Piskun <bigcat88@icloud.com>
2026-01-19 16:47:14 -08:00
ComfyUI Wiki
866a4619db chore: update workflow templates to v0.8.14 (#11974) 2026-01-19 14:21:35 -08:00
comfyanonymous
1a72bf2046 Readme update. (#11957) 2026-01-18 19:53:43 -08:00
Alexander Piskun
034fac7054 chore(api-nodes): auto-discover all nodes_*.py files to avoid merge conflicts when adding new API nodes (#11943) 2026-01-17 22:40:39 -08:00
Christian Byrne
a498556d0d feat: add advanced parameter to Input classes for advanced widgets support (#11939)
Add an 'advanced' boolean parameter to the Input and WidgetInput base classes
and propagate it to all typed Input subclasses (Boolean, Int, Float, String,
Combo, MultiCombo, Webcam, MultiType, MatchType, ImageCompare).

When set to True, the frontend will hide these inputs by default in a
collapsible 'Advanced Inputs' section in the right side panel, reducing
visual clutter for power-user options.

This enables nodes to expose advanced configuration options (like encoding
parameters, quality settings, etc.) without overwhelming typical users.

Frontend support: ComfyUI_frontend PR #7812
2026-01-17 19:06:03 -08:00
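A minimal, self-contained sketch of the idea (toy classes standing in for Input/WidgetInput; the real ComfyUI signatures are not reproduced here, only the new `advanced` flag and how a frontend might group on it):

```python
from dataclasses import dataclass

@dataclass
class WidgetInput:
    """Toy stand-in for the Input/WidgetInput base classes; the new 'advanced'
    flag defaults to False so existing nodes are unaffected."""
    name: str
    advanced: bool = False

@dataclass
class IntInput(WidgetInput):
    default: int = 0
    min: int = 0
    max: int = 100

inputs = [
    IntInput("width", default=512, max=8192),
    IntInput("crf", default=23, max=51, advanced=True),  # power-user option
]

# A frontend would render non-advanced inputs normally and collapse the rest
# into an "Advanced Inputs" section in the side panel.
visible = [i.name for i in inputs if not i.advanced]
collapsed = [i.name for i in inputs if i.advanced]
print(visible, collapsed)  # ['width'] ['crf']
```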
Alexander Piskun
f7ca41ff62 chore(api-nodes): remove check for pyav>=14.2 in code (it was added to requirements.txt long ago) (#11934) 2026-01-17 18:57:57 -08:00
Alexander Piskun
ac26065e61 chore(api-nodes): remove non-used; extract model to separate files (#11927)
* chore(api-nodes): remove non-used; extract model to separate files

* chore(api-nodes): remove non-needed prefix in filenames
2026-01-17 18:52:45 -08:00
comfyanonymous
190c4416cc Bump comfy-kitchen dependency to version 0.2.7 (#11941) 2026-01-17 21:20:35 -05:00
Theephop
0fd10ffa09 fix: use .cpu() for waveform conversion in AudioFrame creation (#11787) 2026-01-17 20:18:24 -05:00
Alex Butler
00c775950a Update readme rdna3 nightly url (#11937) 2026-01-17 20:18:04 -05:00
92 changed files with 3395 additions and 893 deletions

View File

@@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes`
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
### Instructions:
@@ -229,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
@@ -240,7 +240,7 @@ These have less hardware support than the builds above but they work on windows.
RDNA 3 (RX 7000 series):
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/```
```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/```
RDNA 3.5 (Strix halo/Ryzen AI Max+ 365):

comfy/ldm/anima/model.py Normal file
View File

@@ -0,0 +1,202 @@
from comfy.ldm.cosmos.predict2 import MiniTrainDIT
import torch
from torch import nn
import torch.nn.functional as F
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed
class RotaryEmbedding(nn.Module):
def __init__(self, head_dim):
super().__init__()
self.rope_theta = 10000
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Attention(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
super().__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
context = x if context is None else context
input_shape = x.shape[:-1]
q_shape = (*input_shape, self.n_heads, self.head_dim)
context_shape = context.shape[:-1]
kv_shape = (*context_shape, self.n_heads, self.head_dim)
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
if position_embeddings is not None:
assert position_embeddings_context is not None
cos, sin = position_embeddings
query_states = apply_rotary_pos_emb(query_states, cos, sin)
cos, sin = position_embeddings_context
key_states = apply_rotary_pos_emb(key_states, cos, sin)
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
def init_weights(self):
torch.nn.init.zeros_(self.o_proj.weight)
class TransformerBlock(nn.Module):
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
super().__init__()
self.use_self_attn = use_self_attn
if self.use_self_attn:
self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(
query_dim=model_dim,
context_dim=model_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
query_dim=model_dim,
context_dim=source_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.mlp = nn.Sequential(
operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
nn.GELU(),
operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
if self.use_self_attn:
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
return x
def init_weights(self):
torch.nn.init.zeros_(self.mlp[2].weight)
self.cross_attn.init_weights()
class LLMAdapter(nn.Module):
def __init__(
self,
source_dim=1024,
target_dim=1024,
model_dim=1024,
num_layers=6,
num_heads=16,
use_self_attn=True,
layer_norm=False,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
if model_dim != target_dim:
self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
else:
self.in_proj = nn.Identity()
self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
self.blocks = nn.ModuleList([
TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
])
self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
if target_attention_mask is not None:
target_attention_mask = target_attention_mask.to(torch.bool)
if target_attention_mask.ndim == 2:
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
if source_attention_mask is not None:
source_attention_mask = source_attention_mask.to(torch.bool)
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
return self.norm(self.out_proj(x))
class Anima(MiniTrainDIT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids):
if text_ids is not None:
return self.llm_adapter(text_embeds, text_ids)
else:
return text_embeds

View File

@@ -103,20 +103,10 @@ class AudioPreprocessor:
return waveform
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
@staticmethod
def normalize_amplitude(
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
) -> torch.Tensor:
waveform = waveform - waveform.mean(dim=2, keepdim=True)
peak = torch.max(torch.abs(waveform)) + eps
scale = peak.clamp(max=max_amplitude) / peak
return waveform * scale
def waveform_to_mel(
self, waveform: torch.Tensor, waveform_sample_rate: int, device
) -> torch.Tensor:
waveform = self.resample(waveform, waveform_sample_rate)
waveform = self.normalize_amplitude(waveform)
mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_sample_rate,
@@ -189,9 +179,12 @@ class AudioVAE(torch.nn.Module):
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
if waveform.shape[1] == 1:
waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
else:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device

View File

@@ -1,11 +1,11 @@
from typing import Tuple, Union
import threading
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
@@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
padding_mode=spatial_padding_mode,
groups=groups,
)
self.temporal_cache_state={}
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
tid = threading.get_ident()
cached, is_end = self.temporal_cache_state.get(tid, (None, False))
if cached is None:
padding_length = self.time_kernel_size - 1
if not causal:
padding_length = padding_length // 2
if x.shape[2] == 0:
return x
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
@property
def weight(self):

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
import threading
import torch
from torch import nn
from functools import partial
@@ -6,12 +7,35 @@ import math
from einops import rearrange
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
if isinstance(m, CausalConv3d):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)
def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
def add_exchange_cache(dest, cache_in, new_input, dim=2):
if dest is not None:
if cache_in is not None:
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
lead_in_dest.add_(lead_in_source)
body, new_input = split2(new_input, dest.shape[dim], dim)
dest.add_(body)
return torch_cat_if_needed([cache_in, new_input], dim=dim)
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@@ -205,7 +229,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@@ -254,6 +278,22 @@ class Encoder(nn.Module):
return sample
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesn't use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module):
r"""
@@ -341,18 +381,6 @@ class Decoder(nn.Module):
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
@@ -428,8 +456,9 @@ class Decoder(nn.Module):
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
@@ -437,6 +466,7 @@ class Decoder(nn.Module):
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = (
@@ -445,24 +475,12 @@ class Decoder(nn.Module):
else lambda x: x
)
scaled_timestep = None
timestep_shift_scale = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(),
resolution=None,
@@ -483,16 +501,62 @@ class Decoder(nn.Module):
embedded_timestep.shape[-2],
embedded_timestep.shape[-1],
)
shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
timestep_shift_scale = ada_values.unbind(dim=1)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
output = []
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
finally:
for _, module in self.named_modules():
#ComfyUI doesn't thread this kind of stuff today, but just in case
#we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
class UNetMidBlock3D(nn.Module):
"""
@@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
tid = threading.get_ident()
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
y = self.conv(x, causal=causal)
y = rearrange(
y,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
y = y[:, :, 1:, :, :]
drop_first_conv = False
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
@@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
)
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2:
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
drop_first_res = False
if y.shape[2] == 0:
y = None
cached = add_exchange_cache(y, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
else:
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
return y
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
@@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
torch.randn(4, in_channels) / in_channels**0.5
)
self.temporal_cache_state={}
def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
@@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module):
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
tid = threading.get_ident()
cached = self.temporal_cache_state.get(tid, None)
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
self.temporal_cache_state[tid] = cached
return output_tensor
return hidden_states
def patchify(x, patch_size_hw, patch_size_t=1):

View File

@@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
def modulate(x, scale):
return x * (1 + scale.unsqueeze(1))
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
def modulate(x, scale, timestep_zero_index=None):
if timestep_zero_index is None:
return x * (1 + scale.unsqueeze(1))
else:
scale = (1 + scale.unsqueeze(1))
actual_batch = scale.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= scale[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= scale[:actual_batch]
return x
def apply_gate(gate, x, timestep_zero_index=None):
if timestep_zero_index is None:
return gate * x
else:
actual_batch = gate.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= gate[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= gate[:actual_batch]
return x
#############################################################################
# Core NextDiT Model #
@@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
timestep_zero_index=None,
transformer_options={},
):
"""
@@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module):
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
x_mask,
freqs_cis,
transformer_options=transformer_options,
))
))), timestep_zero_index=timestep_zero_index
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
))
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
))), timestep_zero_index=timestep_zero_index
)
else:
assert adaln_input is None
@@ -345,13 +389,37 @@ class FinalLayer(nn.Module):
),
)
def forward(self, x, c):
def forward(self, x, c, timestep_zero_index=None):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
x = self.linear(x)
return x
def pad_zimage(feats, pad_token, pad_tokens_multiple):
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = start_t
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
return x_pos_ids
class NextDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
@@ -378,6 +446,7 @@ class NextDiT(nn.Module):
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
siglip_feat_dim=None,
image_model=None,
device=None,
dtype=None,
@@ -491,6 +560,41 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers)
]
)
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
siglip_feat_dim,
dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.siglip_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
@@ -531,70 +635,168 @@ class NextDiT(nn.Module):
imgs = torch.stack(imgs, dim=0)
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
if cap_feats is not None:
cap_feats = self.cap_embedder(cap_feats)
cap_feats_len = cap_feats.shape[1]
if self.pad_tokens_multiple is not None:
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
else:
cap_feats_len = 0
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
embeds = (cap_feats,)
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
bsz = 1
pH = pW = self.patch_size
device = x.device
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
if (not omni) or self.siglip_embedder is None:
cap_feats_len = embeds[0].shape[1] + offset
embeds += (None,)
freqs_cis += (None,)
else:
cap_feats_len += offset
if siglip_feats is not None:
b, h, w, c = siglip_feats.shape
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
siglip_feats = self.siglip_embedder(siglip_feats)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
if self.siglip_pad_token is not None:
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
else:
if self.siglip_pad_token is not None:
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
if siglip_feats is None:
embeds += (None,)
freqs_cis += (None,)
else:
embeds += (siglip_feats,)
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
embeds += (x,)
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
def patchify_and_embed(
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = x.shape[0]
cap_mask = None # TODO?
main_siglip = None
orig_x = x
embeds = ([], [], [])
freqs_cis = ([], [], [])
leftover_cap = []
start_t = 0
omni = len(ref_latents) > 0
if omni:
for i, ref in enumerate(ref_latents):
if i < len(ref_contexts):
ref_con = ref_contexts[i]
else:
ref_con = None
if i < len(siglip_feats):
sig_feat = siglip_feats[i]
else:
sig_feat = None
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
for i, e in enumerate(out[0]):
if e is not None:
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
freqs_cis[i].append(out[1][i])
start_t = out[2]
leftover_cap = ref_contexts[len(ref_latents):]
H, W = x.shape[-2], x.shape[-1]
img_sizes = [(H, W)] * bsz
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
img_len = out[0][-1].shape[1]
cap_len = out[0][0].shape[1]
for i, e in enumerate(out[0]):
if e is not None:
e = comfy.utils.repeat_to_batch_size(e, bsz)
embeds[i].append(e)
freqs_cis[i].append(out[1][i])
start_t = out[2]
for cap in leftover_cap:
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
cap_len += out[0][0].shape[1]
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
freqs_cis[0].append(out[1][0])
start_t += out[2]
patches = transformer_options.get("patches", {})
# refine context
cap_feats = torch.cat(embeds[0], dim=1)
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
feats = (cap_feats,)
fc = (cap_freqs_cis,)
if omni and len(embeds[1]) > 0:
siglip_mask = None
siglip_feats_combined = torch.cat(embeds[1], dim=1)
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
if self.siglip_refiner is not None:
for layer in self.siglip_refiner:
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
feats += (siglip_feats_combined,)
fc += (siglip_feats_freqs_cis,)
padded_img_mask = None
x = torch.cat(embeds[-1], dim=1)
fc_x = torch.cat(freqs_cis[-1], dim=1)
if omni:
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
else:
timestep_zero_index = None
x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
padded_full_embed = torch.cat(feats + (x,), dim=1)
if timestep_zero_index is not None:
ind = padded_full_embed.shape[1] - x.shape[1]
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -604,7 +806,11 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -619,8 +825,6 @@ class NextDiT(nn.Module):
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
@@ -632,7 +836,7 @@ class NextDiT(nn.Module):
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
@@ -640,7 +844,7 @@ class NextDiT(nn.Module):
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
@@ -649,8 +853,7 @@ class NextDiT(nn.Module):
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
img = self.final_layer(img, adaln_input)
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img

View File

@@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
import xformers.ops
def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1:
return torch.cat(xl, dim)
else:
elif len(xl) == 1:
return xl[0]
else:
return None
def get_timestep_embedding(timesteps, embedding_dim):
"""

View File

@@ -170,8 +170,14 @@ class Attention(nn.Module):
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
@@ -430,6 +436,9 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]

View File

@@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
x(Tensor): Shape [B, L, num_heads, C / num_heads]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
patches = transformer_options.get("patches", {})
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
def qkv_fn_q(x):
@@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module):
transformer_options=transformer_options,
)
if "attn1_patch" in patches:
for p in patches["attn1_patch"]:
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
x = self.o(x)
return x
@@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module):
"""
# assert e.dtype == torch.float32
patches = transformer_options.get("patches", {})
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
@@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module):
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if "attn2_patch" in patches:
for p in patches["attn2_patch"]:
x = p({"x": x, "transformer_options": transformer_options})
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -488,7 +501,7 @@ class WanModel(torch.nn.Module):
self.blocks = nn.ModuleList([
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
for i in range(num_layers)
])
# head
@@ -541,6 +554,7 @@ class WanModel(torch.nn.Module):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
# time embeddings
@@ -738,6 +752,7 @@ class VaceWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
# time embeddings

View File

@@ -0,0 +1,500 @@
import torch
from einops import rearrange, repeat
import comfy
from comfy.ldm.modules.attention import optimized_attention
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
scale = 1.0 / visual_q.shape[-1] ** 0.5
visual_q = visual_q.transpose(1, 2) * scale
B, H, x_seqlens, K = visual_q.shape
x_ref_attn_maps = []
for class_idx, ref_target_mask in enumerate(ref_target_masks):
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
for i in range(0, x_seqlens, chunk_size):
end_i = min(i + chunk_size, x_seqlens)
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
# Apply softmax
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
attn_chunk = (attn_chunk - attn_max).exp()
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
attn_chunk = attn_chunk / (attn_sum + 1e-8)
# Apply mask and sum
masked_attn = attn_chunk * ref_target_mask
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
del attn_chunk, masked_attn
# Average across heads
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
x_ref_attn_maps.append(x_ref_attnmap)
del visual_q, ref_k
return torch.cat(x_ref_attn_maps, dim=0)
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
"""Args:
query (torch.tensor): B M H K
key (torch.tensor): B M H K
shape (tuple): (N_t, N_h, N_w)
ref_target_masks: [B, N_h * N_w]
"""
N_t, N_h, N_w = shape
x_seqlens = N_h * N_w
ref_k = ref_k[:, :x_seqlens]
_, seq_lens, heads, _ = visual_q.shape
class_num, _ = ref_target_masks.shape
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
split_chunk = heads // split_num
for i in range(split_num):
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
ref_target_masks
)
x_ref_attn_maps += x_ref_attn_maps_perhead
return x_ref_attn_maps / split_num
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
source_min, source_max = source_range
new_min, new_max = target_range
normalized = (column - source_min) / (source_max - source_min + epsilon)
scaled = normalized * (new_max - new_min) + new_min
return scaled
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
def get_audio_embeds(encoded_audio, audio_start, audio_end):
audio_embs = []
human_num = len(encoded_audio)
audio_frames = encoded_audio[0].shape[0]
indices = (torch.arange(4 + 1) - 2) * 1
for human_idx in range(human_num):
if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
pad_len = audio_end - audio_frames
pad_shape = list(encoded_audio[human_idx].shape)
pad_shape[0] = pad_len
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
else:
encoded_audio_in = encoded_audio[human_idx]
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
audio_embs.append(audio_emb)
return torch.cat(audio_embs, dim=0)
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
first_frame_audio_emb_s = audio_embs[:, :1, ...]
latter_frame_audio_emb = audio_embs[:, 1:, ...]
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
middle_index = audio_proj.seq_len // 2
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
audio_emb = torch.cat(audio_emb.split(1), dim=2)
return audio_emb
class RotaryPositionalEmbedding1D(torch.nn.Module):
def __init__(self,
head_dim,
):
super().__init__()
self.head_dim = head_dim
self.base = 10000
def precompute_freqs_cis_1d(self, pos_indices):
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
freqs = freqs.to(pos_indices.device)
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
return freqs
def forward(self, x, pos_indices):
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
x_ = x.float()
freqs_cis = freqs_cis.float().to(x.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
x_ = (x_ * cos) + (rotate_half(x_) * sin)
return x_.type_as(x)
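A minimal sketch of applying the 1-D rotary embedding above to a query tensor; shapes follow the (B, heads, tokens, head_dim) layout used later in the attention code (values are illustrative):

```python
import torch

rope = RotaryPositionalEmbedding1D(head_dim=128)
q = torch.randn(1, 8, 256, 128)   # B, heads, tokens, head_dim
positions = torch.zeros(256)      # e.g. all tokens assigned the same bucket position
q_rot = rope(q, positions)
print(q_rot.shape)  # torch.Size([1, 8, 256, 128])
```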
class SingleStreamAttention(torch.nn.Module):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.dim = dim
self.encoder_hidden_states_dim = encoder_hidden_states_dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
N_t, N_h, N_w = shape
expected_tokens = N_t * N_h * N_w
actual_tokens = x.shape[1]
x_extra = None
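# Tokens beyond N_t*N_h*N_w (e.g. an appended reference frame) are split off here and re-appended as zeros after attention, so they receive no audio contribution.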
if actual_tokens != expected_tokens:
x_extra = x[:, -N_h * N_w:, :]
x = x[:, :-N_h * N_w, :]
N_t = N_t - 1
B = x.shape[0]
S = N_h * N_w
x = x.view(B * N_t, S, self.dim)
# get q for hidden_state
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
kv = self.kv_linear(encoder_hidden_states)
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
x = optimized_attention(
q.transpose(1, 2),
encoder_k.transpose(1, 2),
encoder_v.transpose(1, 2),
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
# linear transform
x = self.proj(x.reshape(B * N_t, S, self.dim))
x = x.view(B, N_t * S, self.dim)
if x_extra is not None:
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
return x
class SingleStreamMultiAttention(SingleStreamAttention):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
class_range: int = 24,
class_interval: int = 4,
device=None, dtype=None, operations=None
) -> None:
super().__init__(
dim=dim,
encoder_hidden_states_dim=encoder_hidden_states_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
device=device,
dtype=dtype,
operations=operations
)
# Rotary-embedding layout parameters
self.class_interval = class_interval
self.class_range = class_range
self.max_humans = self.class_range // self.class_interval
# Constant bucket used for background tokens
self.rope_bak = int(self.class_range // 2)
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
def forward(
self,
x: torch.Tensor,
encoder_hidden_states: torch.Tensor,
shape=None,
x_ref_attn_map=None
) -> torch.Tensor:
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
# Single-speaker fall-through
if human_num <= 1:
return super().forward(x, encoder_hidden_states, shape)
N_t, N_h, N_w = shape
x_extra = None
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
x_extra = x[:, -N_h * N_w:, :]
x = x[:, :-N_h * N_w, :]
N_t = N_t - 1
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# Query projection
B, N, C = x.shape
q = self.q_linear(x)
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
# Use `class_range` logic for 2 speakers
rope_h1 = (0, self.class_interval)
rope_h2 = (self.class_range - self.class_interval, self.class_range)
rope_bak = int(self.class_range // 2)
# Normalize and scale attention maps for each speaker
max_values = x_ref_attn_map.max(1).values[:, None, None]
min_values = x_ref_attn_map.min(1).values[:, None, None]
max_min_values = torch.cat([max_values, min_values], dim=2)
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
# Token-wise speaker dominance
max_indices = x_ref_attn_map.argmax(dim=0)
normalized_map = torch.stack([human1, human2, back], dim=1)
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
# Apply rotary to Q
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
q = self.rope_1d(q, normalized_pos)
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
# Keys / Values
_, N_a, _ = encoder_hidden_states.shape
encoder_kv = self.kv_linear(encoder_hidden_states)
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
encoder_k, encoder_v = encoder_kv.unbind(0)
# Rotary for keys: assign the centre of each speaker's bucket to its context tokens
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
encoder_k = self.rope_1d(encoder_k, encoder_pos)
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
# Final attention
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
x = optimized_attention(
q.transpose(1, 2),
encoder_k.transpose(1, 2),
encoder_v.transpose(1, 2),
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
# Linear projection
x = x.reshape(B, N, C)
x = self.proj(x)
# Restore original layout
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
if x_extra is not None:
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
return x
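For the default class_range=24 and class_interval=4, the rotary buckets used above work out as follows (a sketch of the layout only; the actual per-token positions are derived from the attention maps at runtime):

```python
class_range, class_interval = 24, 4
rope_h1 = (0, class_interval)                          # speaker 1 tokens land in [0, 4)
rope_h2 = (class_range - class_interval, class_range)  # speaker 2 tokens land in [20, 24)
rope_bak = class_range // 2                            # background tokens sit at 12

# Context (audio) keys are pinned to the centre of each speaker's bucket:
key_pos_speaker1 = (rope_h1[0] + rope_h1[1]) / 2       # 2.0
key_pos_speaker2 = (rope_h2[0] + rope_h2[1]) / 2       # 22.0
```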
class MultiTalkAudioProjModel(torch.nn.Module):
def __init__(
self,
seq_len: int = 5,
seq_len_vf: int = 12,
blocks: int = 12,
channels: int = 768,
intermediate_dim: int = 512,
out_dim: int = 768,
context_tokens: int = 32,
device=None, dtype=None, operations=None
):
super().__init__()
self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = seq_len * blocks * channels
self.input_dim_vf = seq_len_vf * blocks * channels
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.out_dim = out_dim
# define multiple linear layers
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
def forward(self, audio_embeds, audio_embeds_vf):
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
B, _, _, S, C = audio_embeds.shape
# process audio of first frame
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
# process audio of the latter frames
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
# first projection
audio_embeds = torch.relu(self.proj1(audio_embeds))
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
batch_size_c, N_t, C_a = audio_embeds_c.shape
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
# second projection
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
# normalization and reshape
context_tokens = self.norm(context_tokens)
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
return context_tokens
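A shape-level sketch of the projector, using torch.nn as a stand-in for the `operations` namespace (an assumption for illustration; the real model is built with ComfyUI's ops):

```python
import torch

proj = MultiTalkAudioProjModel(operations=torch.nn)  # defaults: seq_len=5, seq_len_vf=12, blocks=12, channels=768
first = torch.randn(1, 1, 5, 12, 768)     # 5-frame audio window for the first frame
latter = torch.randn(1, 20, 12, 12, 768)  # 12-frame windows for 20 latter frames
tokens = proj(first, latter)
print(tokens.shape)  # torch.Size([1, 21, 32, 768]) -- 32 context tokens per video frame
```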
class WanMultiTalkAttentionBlock(torch.nn.Module):
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
super().__init__()
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
class MultiTalkGetAttnMapPatch:
def __init__(self, ref_target_masks=None):
self.ref_target_masks = ref_target_masks
def __call__(self, kwargs):
transformer_options = kwargs.get("transformer_options", {})
x = kwargs["x"]
if self.ref_target_masks is not None:
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
transformer_options["x_ref_attn_map"] = x_ref_attn_map
return x
class MultiTalkCrossAttnPatch:
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
self.model_patch = model_patch
self.audio_scale = audio_scale
self.ref_target_masks = ref_target_masks
def __call__(self, kwargs):
transformer_options = kwargs.get("transformer_options", {})
block_idx = transformer_options.get("block_index", None)
x = kwargs["x"]
if block_idx is None:
return torch.zeros_like(x)
audio_embeds = transformer_options.get("audio_embeds")
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
norm_x, audio_embeds.to(x.dtype),
shape=transformer_options["grid_sizes"],
x_ref_attn_map=x_ref_attn_map
)
x = x + x_audio * self.audio_scale
return x
def models(self):
return [self.model_patch]
class MultiTalkApplyModelWrapper:
def __init__(self, init_latents):
self.init_latents = init_latents
def __call__(self, executor, x, *args, **kwargs):
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
samples = executor(x, *args, **kwargs)
return samples
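The apply-model wrapper simply pins the first latent frames of the model input to the motion-frame latents before delegating to the wrapped call; a toy sketch with a dummy executor (the real executor is ComfyUI's apply_model wrapper chain):

```python
import torch

init_latents = torch.ones(1, 16, 2, 8, 8)        # hypothetical 2 motion-frame latents
wrapper = MultiTalkApplyModelWrapper(init_latents)

x = torch.zeros(1, 16, 10, 8, 8)                 # hypothetical noisy model input
out = wrapper(lambda x_, *a, **kw: x_, x)        # dummy executor that just returns its input
print(torch.equal(out[:, :, :2], init_latents))  # True -- the first latents were replaced
```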
class InfiniteTalkOuterSampleWrapper:
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
self.motion_frames_latent = motion_frames_latent
self.model_patch = model_patch
self.is_extend = is_extend
def __call__(self, executor, *args, **kwargs):
model_patcher = executor.class_obj.model_patcher
model_options = executor.class_obj.model_options
process_latent_in = model_patcher.model.process_latent_in
# for InfiniteTalk, the first latent(s) of the model input must be replaced on every step
if self.motion_frames_latent is not None:
wrappers = model_options["transformer_options"]["wrappers"]
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
# run the sampling process
result = executor(*args, **kwargs)
# insert motion frames before decoding
if self.is_extend:
overlap = self.motion_frames_latent.shape[2]
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
return result
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
if self.motion_frames_latent is not None:
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
return self

View File

@@ -49,6 +49,7 @@ import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.model_management
import comfy.patcher_extension
@@ -1147,9 +1148,31 @@ class CosmosPredict2(BaseModel):
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Anima(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
t5xxl_ids = kwargs.get("t5xxl_ids", None)
t5xxl_weights = kwargs.get("t5xxl_weights", None)
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -1169,6 +1192,35 @@ class Lumina2(BaseModel):
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni
if clip_vision_outputs is not None and len(clip_vision_outputs) > 0:
sigfeats = []
for clip_vision_output in clip_vision_outputs:
if clip_vision_output is not None:
image_size = clip_vision_output.image_sizes[0]
shape = clip_vision_output.last_hidden_state.shape
sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1]))
if len(sigfeats) > 0:
out['siglip_feats'] = comfy.conds.CONDList(sigfeats)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_contexts = kwargs.get("reference_latents_text_embeds", None)
if ref_contexts is not None:
out['ref_contexts'] = comfy.conds.CONDList(ref_contexts)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class WAN21(BaseModel):
@@ -1526,6 +1578,9 @@ class QwenImage(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

View File

@@ -253,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
dit_config["patch_size"] = 16
dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1)
dit_config["nerf_hidden_size"] = 64
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
@@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["time_scale"] = 1000.0
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
if sig_weight is not None:
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
return dit_config
@@ -547,6 +550,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
dit_config = {}
dit_config["image_model"] = "cosmos_predict2"
if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "anima"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128

View File

@@ -57,6 +57,7 @@ import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.model_patcher
import comfy.lora
@@ -635,14 +636,13 @@ class VAE:
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
if self.latent_channels == 48: # Wan 2.2
if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_input = self.process_output = lambda image: image
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
@@ -1048,6 +1048,7 @@ class TEModel(Enum):
GEMMA_3_12B = 18
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
def detect_te_model(sd):
@@ -1093,6 +1094,8 @@ def detect_te_model(sd):
return TEModel.QWEN3_2B
elif weight.shape[0] == 4096:
return TEModel.QWEN3_8B
elif weight.shape[0] == 1024:
return TEModel.QWEN3_06B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@@ -1233,6 +1236,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:

View File

@@ -23,6 +23,7 @@ import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.anima
from . import supported_models_base
from . import latent_formats
@@ -770,10 +771,24 @@ class Flux2(Flux):
return out
def clip_target(self, state_dict={}):
return None # TODO
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
if len(detect) > 0:
detect["model_type"] = "qwen3_4b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref))
if len(detect) > 0:
detect["model_type"] = "qwen3_8b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref))
if len(detect) > 0:
if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict:
detect["pruned"] = True
return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect))
return None
class GenmoMochi(supported_models_base.BASE):
unet_config = {
@@ -992,6 +1007,36 @@ class CosmosT2IPredict2(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class Anima(supported_models_base.BASE):
unet_config = {
"image_model": "anima",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Anima(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
@@ -1551,6 +1596,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]

View File

@@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
class TAEHV(nn.Module):
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True),
latent_format=None, show_progress_bar=False):
super().__init__()
self.image_channels = 3
self.patch_size = 1
@@ -124,6 +125,9 @@ class TAEHV(nn.Module):
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
elif self.latent_channels == 128: # LTX2
self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True)
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
@@ -131,41 +135,52 @@ class TAEHV(nn.Module):
self.encoder = nn.Sequential(
conv(self.image_channels*self.patch_size**2, 64), act_func,
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow))
self.frames_to_trim = self.t_upscale - 1
self._show_progress_bar = show_progress_bar
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1:
x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
if self.patch_size > 1:
B, T, C, H, W = x.shape
x = x.reshape(B * T, C, H, W)
x = F.pixel_unshuffle(x, self.patch_size)
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
if x.shape[1] % self.t_downscale != 0:
# pad at end to multiple of t_downscale
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
return self.process_out(x)
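The per-frame pixel_unshuffle in the new encode path can be illustrated in isolation (shapes are hypothetical; patch_size 4 matches the LTX2 branch above):

```python
import torch
import torch.nn.functional as F

# [B, T, C, H, W] -> [B, T, C*p*p, H/p, W/p] via per-frame pixel_unshuffle
B, T, C, H, W, p = 1, 4, 3, 64, 64, 4
x = torch.rand(B, T, C, H, W)
x = x.reshape(B * T, C, H, W)
x = F.pixel_unshuffle(x, p)
x = x.reshape(B, T, C * p * p, H // p, W // p)
print(x.shape)  # torch.Size([1, 4, 48, 16, 16])
```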
def decode(self, x, **kwargs):
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1:

View File

@@ -0,0 +1,61 @@
from transformers import Qwen2Tokenizer, T5TokenizerFast
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
class AnimaTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
return self.t5xxl.untokenize(token_weight_pair)
def state_dict(self):
return {}
class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class AnimaTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out = super().encode_token_weights(token_weight_pairs)
out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int)
out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0])))
return out
def te(dtype_llama=None, llama_quantization_metadata=None):
class AnimaTEModel_(AnimaTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return AnimaTEModel_

View File

@@ -10,9 +10,11 @@ import comfy.utils
def llama_detect(state_dict, prefix=""):
out = {}
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)]
for norm_key in norm_keys:
if norm_key in state_dict:
out["dtype_llama"] = state_dict[norm_key].dtype
break
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:

View File

@@ -77,6 +77,28 @@ class Qwen25_3BConfig:
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen3_06BConfig:
vocab_size: int = 151936
hidden_size: int = 1024
intermediate_size: int = 3072
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
@@ -641,6 +663,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@@ -118,9 +118,18 @@ class LTXAVTEModel(torch.nn.Module):
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
missing, unexpected = self.load_state_dict(sdo, strict=False)
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
return (missing, unexpected)
missing_all = []
unexpected_all = []
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False)
missing_all.extend([f"{prefix}{k}" for k in missing])
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
return (missing_all, unexpected_all)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0

View File

@@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return OvisTEModel_

View File

@@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ZImageTEModel_

View File

@@ -611,6 +611,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr
"ff.linear_in.bias": "img_mlp.0.bias",
"ff.linear_out.weight": "img_mlp.2.weight",
"ff.linear_out.bias": "img_mlp.2.bias",
"ff_context.linear_in.weight": "txt_mlp.0.weight",
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
@@ -639,6 +647,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}
for k in block_map:

View File

@@ -1,10 +1,12 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput
from .video_types import VideoInput, VideoOp, SliceOp
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
"VideoOp",
"SliceOp",
"MaskInput",
"LatentInput",
]

View File

@@ -1,11 +1,48 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional, Union, IO
import copy
import io
import av
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoOp(ABC):
"""Base class for lazy video operations."""
@abstractmethod
def apply(self, components: VideoComponents) -> VideoComponents:
pass
@abstractmethod
def compute_frame_count(self, input_frame_count: int) -> int:
pass
@dataclass(frozen=True)
class SliceOp(VideoOp):
"""Extract a range of frames from the video."""
start_frame: int
frame_count: int
def apply(self, components: VideoComponents) -> VideoComponents:
total = components.images.shape[0]
start = max(0, min(self.start_frame, total))
end = min(start + self.frame_count, total)
return VideoComponents(
images=components.images[start:end],
audio=components.audio,
frame_rate=components.frame_rate,
metadata=getattr(components, 'metadata', None),
)
def compute_frame_count(self, input_frame_count: int) -> int:
start = max(0, min(self.start_frame, input_frame_count))
return min(self.frame_count, input_frame_count - start)
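A small sketch of how SliceOp clamps against the real frame count when used for the lazy count (values are illustrative):

```python
op = SliceOp(start_frame=10, frame_count=50)
print(op.compute_frame_count(40))  # 30 -- only 30 frames remain after frame 10
print(op.compute_frame_count(5))   # 0  -- start is clamped to the end of the video
```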
class VideoInput(ABC):
"""
Abstract base class for video input types.
@@ -21,6 +58,12 @@ class VideoInput(ABC):
"""
pass
def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
"""Return a copy of this video with a slice operation appended."""
new = copy.copy(self)
new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
return new
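A usage sketch of the new lazy API on a components-backed video (tensor and frame rate are made up); the slice is only applied when get_components() materializes the result:

```python
import torch
from fractions import Fraction
# Assumes VideoComponents and VideoFromComponents are imported from the comfy_api video types.

components = VideoComponents(images=torch.rand(100, 64, 64, 3),  # hypothetical 100-frame RGB video
                             audio=None, frame_rate=Fraction(24))
video = VideoFromComponents(components)

clip = video.sliced(start_frame=10, frame_count=16)
print(clip.get_frame_count())                 # 16 (computed lazily)
print(clip.get_components().images.shape[0])  # 16 (slice applied on materialization)
```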
@abstractmethod
def save_to(
self,

View File

@@ -1,7 +1,8 @@
from .video_types import VideoFromFile, VideoFromComponents
from .._input import SliceOp
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
"SliceOp",
]

View File

@@ -3,7 +3,7 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from .._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput, VideoOp
import av
import io
import json
@@ -63,6 +63,8 @@ class VideoFromFile(VideoInput):
containing the file contents.
"""
self.__file = file
self._operations: list[VideoOp] = []
self.__materialized: Optional[VideoFromComponents] = None
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -161,6 +163,10 @@ class VideoFromFile(VideoInput):
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
# Apply operations to get final frame count
for op in self._operations:
frame_count = op.compute_frame_count(frame_count)
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -239,10 +245,18 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if self.__materialized is not None:
return self.__materialized.get_components()
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
components = self.get_components_internal(container)
for op in self._operations:
components = op.apply(components)
self.__materialized = VideoFromComponents(components)
self._operations = []
return components
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
@@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput):
def __init__(self, components: VideoComponents):
self.__components = components
self._operations: list[VideoOp] = []
def get_components(self) -> VideoComponents:
if self._operations:
components = self.__components
for op in self._operations:
components = op.apply(components)
self.__components = components
self._operations = []
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def get_frame_count(self) -> int:
count = int(self.__components.images.shape[0])
for op in self._operations:
count = op.compute_frame_count(count)
return count
def save_to(
self,
path: str,
@@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
# Materialize ops before saving
components = self.get_components()
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput):
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.width = components.images.shape[2]
video_stream.height = components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
if components.audio:
audio_sample_rate = int(components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
# Encode video
for i, frame in enumerate(self.__components.images):
for i, frame in enumerate(components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
@@ -371,10 +401,10 @@ class VideoFromComponents(VideoInput):
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
if audio_stream and components.audio:
waveform = components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))

View File

@@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO):
Type = Any
@comfytype(io_type="MODEL_PATCH")
class MODEL_PATCH(ComfyTypeIO):
class ModelPatch(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER")
@@ -1000,20 +1000,38 @@ class Autogrow(ComfyTypeI):
names = [f"{prefix}{i}" for i in range(max)]
# need to create a new input based on the contents of input
template_input = None
for _, dict_input in input.items():
# for now, get just the first value from dict_input
template_required = True
for _input_type, dict_input in input.items():
# for now, get just the first value from dict_input; if not required, min can be ignored
if len(dict_input) == 0:
continue
template_input = list(dict_input.values())[0]
template_required = _input_type == "required"
break
if template_input is None:
raise Exception("template_input could not be determined from required or optional; this should never happen.")
new_dict = {}
new_dict_added_to = False
# first, add possible inputs into out_dict
for i, name in enumerate(names):
expected_id = finalize_prefix(curr_prefix, name)
# required
if i < min and template_required:
out_dict["required"][expected_id] = template_input
type_dict = new_dict.setdefault("required", {})
# optional
else:
out_dict["optional"][expected_id] = template_input
type_dict = new_dict.setdefault("optional", {})
if expected_id in live_inputs:
# required
if i < min:
type_dict = new_dict.setdefault("required", {})
# optional
else:
type_dict = new_dict.setdefault("optional", {})
# NOTE: prefix gets added in parse_class_inputs
type_dict[name] = template_input
new_dict_added_to = True
# account for the edge case that all inputs are optional and no values are received
if not new_dict_added_to:
finalized_prefix = finalize_prefix(curr_prefix)
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
@@ -1151,6 +1169,8 @@ class V3Data(TypedDict):
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
dynamic_paths: dict[str, Any]
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
dynamic_paths_default_value: dict[str, Any]
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
@@ -1229,6 +1249,7 @@ class NodeInfoV1:
experimental: bool=None
api_node: bool=None
price_badge: dict | None = None
search_aliases: list[str]=None
@dataclass
class NodeInfoV3:
@@ -1326,6 +1347,8 @@ class Schema:
hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
search_aliases: list[str] = field(default_factory=list)
"""Alternative names for search. Useful for synonyms, abbreviations, or old names after renaming."""
is_input_list: bool = False
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
@@ -1463,6 +1486,7 @@ class Schema:
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
search_aliases=self.search_aliases if self.search_aliases else None,
)
return info
@@ -1504,6 +1528,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
"required": {},
"optional": {},
"dynamic_paths": {},
"dynamic_paths_default_value": {},
}
d = d.copy()
# ignore hidden for parsing
@@ -1513,8 +1538,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
out_dict["hidden"] = hidden
v3_data = {}
dynamic_paths = out_dict.pop("dynamic_paths", None)
if dynamic_paths is not None:
if dynamic_paths is not None and len(dynamic_paths) > 0:
v3_data["dynamic_paths"] = dynamic_paths
# this list is used for autogrow, in the case all inputs are optional and no values are passed
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
return out_dict, hidden, v3_data
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
@@ -1551,11 +1580,16 @@ def add_to_dict_v1(i: Input, d: dict):
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
class DynamicPathsDefaultValue:
EMPTY_DICT = "empty_dict"
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
if paths is None:
return values
values = values.copy()
result = {}
create_tuple = v3_data.get("create_dynamic_tuple", False)
@@ -1569,6 +1603,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
if is_last:
value = values.pop(key, None)
if value is None:
# see if a default value was provided for this key
default_option = default_value_dict.get(key, None)
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
value = {}
if create_tuple:
value = (value, key)
current[p] = value
@@ -1999,6 +2038,7 @@ __all__ = [
"ControlNet",
"Vae",
"Model",
"ModelPatch",
"ClipVision",
"ClipVisionOutput",
"AudioEncoder",

View File

@@ -1,65 +0,0 @@
# ComfyUI API Nodes
## Introduction
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
## Development
While developing, you should be testing against the Staging environment. To test against staging:
**Install ComfyUI_frontend**
Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication.
```bash
python main.py --comfy-api-base https://stagingapi.comfy.org
```
To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging.
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
### Redocly Instructions
**Tip**
When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
```bash
# Download the OpenAPI file from staging server.
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```
# Merging to Master
Before merging to comfyanonymous/ComfyUI master, follow these steps:
1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
1. Make sure the ComfyUI API is deployed to prod with your changes.
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
```bash
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://api.comfy.org/openapi
# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```

View File

@@ -0,0 +1,61 @@
from typing import TypedDict
from pydantic import BaseModel, Field
class InputModerationSettings(TypedDict):
prompt_content_moderation: bool
visual_input_moderation: bool
visual_output_moderation: bool
class BriaEditImageRequest(BaseModel):
instruction: str | None = Field(...)
structured_instruction: str | None = Field(
...,
description="Use this instead of instruction for precise, programmatic control.",
)
images: list[str] = Field(
...,
description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.",
)
mask: str | None = Field(
None,
description="Mask image (black and white). Black areas will be preserved, white areas will be edited. "
"If omitted, the edit applies to the entire image. "
"The input image and the the input mask must be of the same size.",
)
negative_prompt: str | None = Field(None)
guidance_scale: float = Field(...)
model_version: str = Field(...)
steps_num: int = Field(...)
seed: int = Field(...)
ip_signal: bool = Field(
False,
description="If true, returns a warning for potential IP content in the instruction.",
)
prompt_content_moderation: bool = Field(
False, description="If true, returns 422 on instruction moderation failure."
)
visual_input_content_moderation: bool = Field(
False, description="If true, returns 422 on images or mask moderation failure."
)
visual_output_content_moderation: bool = Field(
False, description="If true, returns 422 on visual output moderation failure."
)
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
class BriaResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
class BriaResponse(BaseModel):
status: str = Field(...)
result: BriaResult | None = Field(None)

View File

@@ -0,0 +1,292 @@
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from datetime import datetime
from pydantic import BaseModel, Field, RootModel, StrictBytes
class IdeogramColorPalette1(BaseModel):
name: str = Field(..., description='Name of the preset color palette')
class Member(BaseModel):
color: Optional[str] = Field(
None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$'
)
weight: Optional[float] = Field(
None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0
)
class IdeogramColorPalette2(BaseModel):
members: List[Member] = Field(
..., description='Array of color definitions with optional weights'
)
class IdeogramColorPalette(
RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]]
):
root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field(
...,
description='A color palette specification that can either use a preset name or explicit color definitions with weights',
)
class ImageRequest(BaseModel):
aspect_ratio: Optional[str] = Field(
None,
description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
)
color_palette: Optional[Dict[str, Any]] = Field(
None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
)
magic_prompt_option: Optional[str] = Field(
None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
)
model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
negative_prompt: Optional[str] = Field(
None,
description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
)
num_images: Optional[int] = Field(
1,
description='Optional. Number of images to generate (1-8). Defaults to 1.',
ge=1,
le=8,
)
prompt: str = Field(
..., description='Required. The prompt to use to generate the image.'
)
resolution: Optional[str] = Field(
None,
description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
)
seed: Optional[int] = Field(
None,
description='Optional. A number between 0 and 2147483647.',
ge=0,
le=2147483647,
)
style_type: Optional[str] = Field(
None,
description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
)
class IdeogramGenerateRequest(BaseModel):
image_request: ImageRequest = Field(
..., description='The image generation request parameters.'
)
class Datum(BaseModel):
is_image_safe: Optional[bool] = Field(
None, description='Indicates whether the image is considered safe.'
)
prompt: Optional[str] = Field(
None, description='The prompt used to generate this image.'
)
resolution: Optional[str] = Field(
None, description="The resolution of the generated image (e.g., '1024x1024')."
)
seed: Optional[int] = Field(
None, description='The seed value used for this generation.'
)
style_type: Optional[str] = Field(
None,
description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').",
)
url: Optional[str] = Field(None, description='URL to the generated image.')
class IdeogramGenerateResponse(BaseModel):
created: Optional[datetime] = Field(
None, description='Timestamp when the generation was created.'
)
data: Optional[List[Datum]] = Field(
None, description='Array of generated image information.'
)
class StyleCode(RootModel[str]):
root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$')
class Datum1(BaseModel):
is_image_safe: Optional[bool] = None
prompt: Optional[str] = None
resolution: Optional[str] = None
seed: Optional[int] = None
style_type: Optional[str] = None
url: Optional[str] = None
class IdeogramV3IdeogramResponse(BaseModel):
created: Optional[datetime] = None
data: Optional[List[Datum1]] = None
class RenderingSpeed1(str, Enum):
TURBO = 'TURBO'
DEFAULT = 'DEFAULT'
QUALITY = 'QUALITY'
class IdeogramV3ReframeRequest(BaseModel):
color_palette: Optional[Dict[str, Any]] = None
image: Optional[StrictBytes] = None
num_images: Optional[int] = Field(None, ge=1, le=8)
rendering_speed: Optional[RenderingSpeed1] = None
resolution: str
seed: Optional[int] = Field(None, ge=0, le=2147483647)
style_codes: Optional[List[str]] = None
style_reference_images: Optional[List[StrictBytes]] = None
class MagicPrompt(str, Enum):
AUTO = 'AUTO'
ON = 'ON'
OFF = 'OFF'
class StyleType(str, Enum):
AUTO = 'AUTO'
GENERAL = 'GENERAL'
REALISTIC = 'REALISTIC'
DESIGN = 'DESIGN'
class IdeogramV3RemixRequest(BaseModel):
aspect_ratio: Optional[str] = None
color_palette: Optional[Dict[str, Any]] = None
image: Optional[StrictBytes] = None
image_weight: Optional[int] = Field(50, ge=1, le=100)
magic_prompt: Optional[MagicPrompt] = None
negative_prompt: Optional[str] = None
num_images: Optional[int] = Field(None, ge=1, le=8)
prompt: str
rendering_speed: Optional[RenderingSpeed1] = None
resolution: Optional[str] = None
seed: Optional[int] = Field(None, ge=0, le=2147483647)
style_codes: Optional[List[str]] = None
style_reference_images: Optional[List[StrictBytes]] = None
style_type: Optional[StyleType] = None
class IdeogramV3ReplaceBackgroundRequest(BaseModel):
color_palette: Optional[Dict[str, Any]] = None
image: Optional[StrictBytes] = None
magic_prompt: Optional[MagicPrompt] = None
num_images: Optional[int] = Field(None, ge=1, le=8)
prompt: str
rendering_speed: Optional[RenderingSpeed1] = None
seed: Optional[int] = Field(None, ge=0, le=2147483647)
style_codes: Optional[List[str]] = None
style_reference_images: Optional[List[StrictBytes]] = None
class ColorPalette(BaseModel):
name: str = Field(..., description='Name of the color palette', examples=['PASTEL'])
class MagicPrompt2(str, Enum):
ON = 'ON'
OFF = 'OFF'
class StyleType1(str, Enum):
AUTO = 'AUTO'
GENERAL = 'GENERAL'
REALISTIC = 'REALISTIC'
DESIGN = 'DESIGN'
FICTION = 'FICTION'
class RenderingSpeed(str, Enum):
DEFAULT = 'DEFAULT'
TURBO = 'TURBO'
QUALITY = 'QUALITY'
class IdeogramV3EditRequest(BaseModel):
color_palette: Optional[IdeogramColorPalette] = None
image: Optional[StrictBytes] = Field(
None,
description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.',
)
magic_prompt: Optional[str] = Field(
None,
description='Determine if MagicPrompt should be used in generating the request or not.',
)
mask: Optional[StrictBytes] = Field(
None,
description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.',
)
num_images: Optional[int] = Field(
None, description='The number of images to generate.'
)
prompt: str = Field(
..., description='The prompt used to describe the edited result.'
)
rendering_speed: RenderingSpeed
seed: Optional[int] = Field(
None, description='Random seed. Set for reproducible generation.'
)
style_codes: Optional[List[StyleCode]] = Field(
None,
description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.',
)
style_reference_images: Optional[List[StrictBytes]] = Field(
None,
description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class IdeogramV3Request(BaseModel):
aspect_ratio: Optional[str] = Field(
None, description='Aspect ratio in format WxH', examples=['1x3']
)
color_palette: Optional[ColorPalette] = None
magic_prompt: Optional[MagicPrompt2] = Field(
None, description='Whether to enable magic prompt enhancement'
)
negative_prompt: Optional[str] = Field(
None, description='Text prompt specifying what to avoid in the generation'
)
num_images: Optional[int] = Field(
None, description='Number of images to generate', ge=1
)
prompt: str = Field(..., description='The text prompt for image generation')
rendering_speed: RenderingSpeed
resolution: Optional[str] = Field(
None, description='Image resolution in format WxH', examples=['1280x800']
)
seed: Optional[int] = Field(
None, description='Seed value for reproducible generation'
)
style_codes: Optional[List[StyleCode]] = Field(
None, description='Array of style codes in hexadecimal format'
)
style_reference_images: Optional[List[str]] = Field(
None, description='Array of reference image URLs or identifiers'
)
style_type: Optional[StyleType1] = Field(
None, description='The type of style to apply'
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)

View File

@@ -0,0 +1,152 @@
from enum import Enum
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, StrictBytes
class MoonvalleyPromptResponse(BaseModel):
error: Optional[Dict[str, Any]] = None
frame_conditioning: Optional[Dict[str, Any]] = None
id: Optional[str] = None
inference_params: Optional[Dict[str, Any]] = None
meta: Optional[Dict[str, Any]] = None
model_params: Optional[Dict[str, Any]] = None
output_url: Optional[str] = None
prompt_text: Optional[str] = None
status: Optional[str] = None
class MoonvalleyTextToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
75, description='Number of cooldown steps (calculated based on num_frames)'
)
fps: Optional[int] = Field(
24, description='Frames per second of the generated video'
)
guidance_scale: Optional[float] = Field(
10, description='Guidance scale for generation control'
)
height: Optional[int] = Field(
1080, description='Height of the generated video in pixels'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
0, description='Number of warmup steps (calculated based on num_frames)'
)
width: Optional[int] = Field(
1920, description='Width of the generated video in pixels'
)
class MoonvalleyTextToVideoRequest(BaseModel):
image_url: Optional[str] = None
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
prompt_text: Optional[str] = None
webhook_url: Optional[str] = None
class MoonvalleyUploadFileRequest(BaseModel):
file: Optional[StrictBytes] = None
class MoonvalleyUploadFileResponse(BaseModel):
access_url: Optional[str] = None
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
36, description='Number of cooldown steps (calculated based on num_frames)'
)
guidance_scale: Optional[float] = Field(
15, description='Guidance scale for generation control'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
24, description='Number of warmup steps (calculated based on num_frames)'
)
class ControlType(str, Enum):
motion_control = 'motion_control'
pose_control = 'pose_control'
class MoonvalleyVideoToVideoRequest(BaseModel):
control_type: ControlType = Field(
..., description='Supported types for video control'
)
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
prompt_text: str = Field(..., description='Describes the video to generate')
video_url: str = Field(..., description='Url to control video')
webhook_url: Optional[str] = Field(
None, description='Optional webhook URL for notifications'
)

View File

@@ -0,0 +1,170 @@
from pydantic import BaseModel, Field
class Datum2(BaseModel):
b64_json: str | None = Field(None, description="Base64 encoded image data")
revised_prompt: str | None = Field(None, description="Revised prompt")
url: str | None = Field(None, description="URL of the image")
class InputTokensDetails(BaseModel):
image_tokens: int | None = Field(None)
text_tokens: int | None = Field(None)
class Usage(BaseModel):
input_tokens: int | None = Field(None)
input_tokens_details: InputTokensDetails | None = Field(None)
output_tokens: int | None = Field(None)
total_tokens: int | None = Field(None)
class OpenAIImageGenerationResponse(BaseModel):
data: list[Datum2] | None = Field(None)
usage: Usage | None = Field(None)
class OpenAIImageEditRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str = Field(...)
moderation: str | None = Field(None)
n: int | None = Field(None, description="The number of images to generate")
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
size: str | None = Field(None, description="Size of the output image")
class OpenAIImageGenerationRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str | None = Field(None)
moderation: str | None = Field(None)
n: int | None = Field(
None,
description="The number of images to generate.",
)
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="The quality of the generated image")
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
class ModelResponseProperties(BaseModel):
instructions: str | None = Field(None)
max_output_tokens: int | None = Field(None)
model: str | None = Field(None)
temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0)
top_p: float | None = Field(
1,
description="Controls diversity of the response via nucleus sampling",
ge=0.0,
le=1.0,
)
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
class ResponseProperties(BaseModel):
instructions: str | None = Field(None)
max_output_tokens: int | None = Field(None)
model: str | None = Field(None)
previous_response_id: str | None = Field(None)
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
class ResponseError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class OutputTokensDetails(BaseModel):
reasoning_tokens: int = Field(..., description="The number of reasoning tokens.")
class CachedTokensDetails(BaseModel):
cached_tokens: int = Field(
...,
description="The number of tokens that were retrieved from the cache.",
)
class ResponseUsage(BaseModel):
input_tokens: int = Field(..., description="The number of input tokens.")
input_tokens_details: CachedTokensDetails = Field(...)
output_tokens: int = Field(..., description="The number of output tokens.")
output_tokens_details: OutputTokensDetails = Field(...)
total_tokens: int = Field(..., description="The total number of tokens used.")
class InputTextContent(BaseModel):
text: str = Field(..., description="The text input to the model.")
type: str = Field("input_text")
class OutputContent(BaseModel):
type: str = Field(..., description="The type of output content")
text: str | None = Field(None, description="The text content")
data: str | None = Field(None, description="Base64-encoded audio data")
transcript: str | None = Field(None, description="Transcript of the audio")
class OutputMessage(BaseModel):
type: str = Field(..., description="The type of output item")
content: list[OutputContent] | None = Field(None, description="The content of the message")
role: str | None = Field(None, description="The role of the message")
class OpenAIResponse(ModelResponseProperties, ResponseProperties):
created_at: float | None = Field(
None,
description="Unix timestamp (in seconds) of when this Response was created.",
)
error: ResponseError | None = Field(None)
id: str | None = Field(None, description="Unique identifier for this Response.")
object: str | None = Field(None, description="The object type of this resource - always set to `response`.")
output: list[OutputMessage] | None = Field(None)
parallel_tool_calls: bool | None = Field(True)
status: str | None = Field(
None,
description="One of `completed`, `failed`, `in_progress`, or `incomplete`.",
)
usage: ResponseUsage | None = Field(None)
class InputImageContent(BaseModel):
detail: str = Field(..., description="One of `high`, `low`, or `auto`. Defaults to `auto`.")
file_id: str | None = Field(None)
image_url: str | None = Field(None)
type: str = Field(..., description="The type of the input item. Always `input_image`.")
class InputFileContent(BaseModel):
file_data: str | None = Field(None)
file_id: str | None = Field(None)
filename: str | None = Field(None, description="The name of the file to be sent to the model.")
type: str = Field(..., description="The type of the input item. Always `input_file`.")
class InputMessage(BaseModel):
content: list[InputTextContent | InputImageContent | InputFileContent] = Field(
...,
description="A list of one or many input items to the model, containing different content types.",
)
role: str | None = Field(None)
type: str | None = Field(None)
class OpenAICreateResponse(ModelResponseProperties, ResponseProperties):
include: str | None = Field(None)
input: list[InputMessage] = Field(...)
parallel_tool_calls: bool | None = Field(
True, description="Whether to allow the model to run tool calls in parallel."
)
store: bool | None = Field(
True,
description="Whether to store the generated model response for later retrieval via API.",
)
stream: bool | None = Field(False)
usage: ResponseUsage | None = Field(None)

View File

@@ -1,52 +0,0 @@
from pydantic import BaseModel, Field
class Datum2(BaseModel):
b64_json: str | None = Field(None, description="Base64 encoded image data")
revised_prompt: str | None = Field(None, description="Revised prompt")
url: str | None = Field(None, description="URL of the image")
class InputTokensDetails(BaseModel):
image_tokens: int | None = None
text_tokens: int | None = None
class Usage(BaseModel):
input_tokens: int | None = None
input_tokens_details: InputTokensDetails | None = None
output_tokens: int | None = None
total_tokens: int | None = None
class OpenAIImageGenerationResponse(BaseModel):
data: list[Datum2] | None = None
usage: Usage | None = None
class OpenAIImageEditRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str = Field(...)
moderation: str | None = Field(None)
n: int | None = Field(None, description="The number of images to generate")
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
size: str | None = Field(None, description="Size of the output image")
class OpenAIImageGenerationRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str | None = Field(None)
moderation: str | None = Field(None)
n: int | None = Field(
None,
description="The number of images to generate.",
)
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="The quality of the generated image")
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")

View File

@@ -0,0 +1,127 @@
from enum import Enum
from typing import Optional, List, Union
from datetime import datetime
from pydantic import BaseModel, Field, RootModel
class RunwayAspectRatioEnum(str, Enum):
field_1280_720 = '1280:720'
field_720_1280 = '720:1280'
field_1104_832 = '1104:832'
field_832_1104 = '832:1104'
field_960_960 = '960:960'
field_1584_672 = '1584:672'
field_1280_768 = '1280:768'
field_768_1280 = '768:1280'
class Position(str, Enum):
first = 'first'
last = 'last'
class RunwayPromptImageDetailedObject(BaseModel):
position: Position = Field(
...,
description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.",
)
uri: str = Field(
..., description='A HTTPS URL or data URI containing an encoded image.'
)
class RunwayPromptImageObject(
RootModel[Union[str, List[RunwayPromptImageDetailedObject]]]
):
root: Union[str, List[RunwayPromptImageDetailedObject]] = Field(
...,
description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.',
)
class RunwayModelEnum(str, Enum):
gen4_turbo = 'gen4_turbo'
gen3a_turbo = 'gen3a_turbo'
class RunwayDurationEnum(int, Enum):
integer_5 = 5
integer_10 = 10
class RunwayImageToVideoRequest(BaseModel):
duration: RunwayDurationEnum
model: RunwayModelEnum
promptImage: RunwayPromptImageObject
promptText: Optional[str] = Field(
None, description='Text prompt for the generation', max_length=1000
)
ratio: RunwayAspectRatioEnum
seed: int = Field(
..., description='Random seed for generation', ge=0, le=4294967295
)
class RunwayImageToVideoResponse(BaseModel):
id: Optional[str] = Field(None, description='Task ID')
class RunwayTaskStatusEnum(str, Enum):
SUCCEEDED = 'SUCCEEDED'
RUNNING = 'RUNNING'
FAILED = 'FAILED'
PENDING = 'PENDING'
CANCELLED = 'CANCELLED'
THROTTLED = 'THROTTLED'
class RunwayTaskStatusResponse(BaseModel):
createdAt: datetime = Field(..., description='Task creation timestamp')
id: str = Field(..., description='Task ID')
output: Optional[List[str]] = Field(None, description='Array of output video URLs')
progress: Optional[float] = Field(
None,
description='Float value between 0 and 1 representing the progress of the task. Only available if status is RUNNING.',
ge=0.0,
le=1.0,
)
status: RunwayTaskStatusEnum
class Model4(str, Enum):
gen4_image = 'gen4_image'
class ReferenceImage(BaseModel):
uri: Optional[str] = Field(
None, description='A HTTPS URL or data URI containing an encoded image'
)
class RunwayTextToImageAspectRatioEnum(str, Enum):
field_1920_1080 = '1920:1080'
field_1080_1920 = '1080:1920'
field_1024_1024 = '1024:1024'
field_1360_768 = '1360:768'
field_1080_1080 = '1080:1080'
field_1168_880 = '1168:880'
field_1440_1080 = '1440:1080'
field_1080_1440 = '1080:1440'
field_1808_768 = '1808:768'
field_2112_912 = '2112:912'
class RunwayTextToImageRequest(BaseModel):
model: Model4 = Field(..., description='Model to use for generation')
promptText: str = Field(
..., description='Text prompt for the image generation', max_length=1000
)
ratio: RunwayTextToImageAspectRatioEnum
referenceImages: Optional[List[ReferenceImage]] = Field(
None, description='Array of reference images to guide the generation'
)
class RunwayTextToImageResponse(BaseModel):
id: Optional[str] = Field(None, description='Task ID')

View File

@@ -41,7 +41,7 @@ class Resolution(BaseModel):
height: int = Field(...)
class CreateCreateVideoRequestSource(BaseModel):
class CreateVideoRequestSource(BaseModel):
container: str = Field(...)
size: int = Field(..., description="Size of the video file in bytes")
duration: int = Field(..., description="Duration of the video file in seconds")
@@ -89,7 +89,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel):
source: CreateCreateVideoRequestSource = Field(...)
source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@@ -0,0 +1,35 @@
from pydantic import BaseModel, Field
class SeedVR2ImageRequest(BaseModel):
image: str = Field(...)
target_resolution: str = Field(...)
output_format: str = Field("png")
enable_sync_mode: bool = Field(False)
class FlashVSRRequest(BaseModel):
target_resolution: str = Field(...)
video: str = Field(...)
duration: float = Field(...)
class TaskCreatedDataResponse(BaseModel):
id: str = Field(...)
class TaskCreatedResponse(BaseModel):
code: int = Field(...)
message: str = Field(...)
data: TaskCreatedDataResponse | None = Field(None)
class TaskResultDataResponse(BaseModel):
status: str = Field(...)
outputs: list[str] = Field([])
class TaskResultResponse(BaseModel):
code: int = Field(...)
message: str = Field(...)
data: TaskResultDataResponse | None = Field(None)

View File

@@ -1,10 +0,0 @@
import av
ver = av.__version__.split(".")
if int(ver[0]) < 14:
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
if int(ver[0]) == 14 and int(ver[1]) < 2:
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
NODE_CLASS_MAPPINGS = {}

View File

@@ -1,116 +0,0 @@
from enum import Enum
from pydantic.fields import FieldInfo
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from comfy.comfy_types.node_typing import IO, InputTypeOptions
NodeInput = tuple[IO, InputTypeOptions]
def _create_base_config(field_info: FieldInfo) -> InputTypeOptions:
config = {}
if hasattr(field_info, "default") and field_info.default is not PydanticUndefined:
config["default"] = field_info.default
if hasattr(field_info, "description") and field_info.description is not None:
config["tooltip"] = field_info.description
return config
def _get_number_constraints_config(field_info: FieldInfo) -> dict:
config = {}
if hasattr(field_info, "metadata"):
metadata = field_info.metadata
for constraint in metadata:
if hasattr(constraint, "ge"):
config["min"] = constraint.ge
if hasattr(constraint, "le"):
config["max"] = constraint.le
if hasattr(constraint, "multiple_of"):
config["step"] = constraint.multiple_of
return config
def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.IMAGE, {
**_create_base_config(field_info),
**kwargs,
}
def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.STRING, {
**_create_base_config(field_info),
**kwargs,
}
def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.FLOAT, {
**_create_base_config(field_info),
**_get_number_constraints_config(field_info),
**kwargs,
}
def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.INT, {
**_create_base_config(field_info),
**_get_number_constraints_config(field_info),
**kwargs,
}
def _model_field_to_combo_input(
field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs
) -> NodeInput:
combo_config = {}
if enum_type is not None:
combo_config["options"] = [option.value for option in enum_type]
combo_config = {
**combo_config,
**_create_base_config(field_info),
**kwargs,
}
return IO.COMBO, combo_config
def model_field_to_node_input(
input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs
) -> NodeInput:
"""
Maps a field from a Pydantic model to a Comfy node input.
Args:
input_type: The type of the input.
base_model: The Pydantic model to map the field from.
field_name: The name of the field to map.
**kwargs: Additional key/values to include in the input options.
Note:
For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically.
Example:
>>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True)
>>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum)
>>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True)
"""
field_info: FieldInfo = base_model.model_fields[field_name]
result: NodeInput
if input_type == IO.IMAGE:
result = _model_field_to_image_input(field_info, **kwargs)
elif input_type == IO.STRING:
result = _model_field_to_string_input(field_info, **kwargs)
elif input_type == IO.FLOAT:
result = _model_field_to_float_input(field_info, **kwargs)
elif input_type == IO.INT:
result = _model_field_to_int_input(field_info, **kwargs)
elif input_type == IO.COMBO:
result = _model_field_to_combo_input(field_info, **kwargs)
else:
message = f"Invalid input type: {input_type}"
raise ValueError(message)
return result

View File

@@ -3,7 +3,7 @@ from pydantic import BaseModel
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl_api import (
from comfy_api_nodes.apis.bfl import (
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest,

View File

@@ -0,0 +1,198 @@
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
BriaResponse,
BriaStatusResponse,
InputModerationSettings,
)
from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
poll_op,
sync_op,
upload_images_to_comfyapi,
)
class BriaImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaImageEditNode",
display_name="Bria FIBO Image Edit",
category="api node/image/Bria",
description="Edit images using Bria latest model",
inputs=[
IO.Combo.Input("model", options=["FIBO"]),
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Instruction to edit image",
),
IO.String.Input("negative_prompt", multiline=True, default=""),
IO.String.Input(
"structured_prompt",
multiline=True,
default="",
tooltip="A string containing the structured edit prompt in JSON format. "
"Use this instead of usual prompt for precise, programmatic control.",
),
IO.Int.Input(
"seed",
default=1,
min=1,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Float.Input(
"guidance_scale",
default=3,
min=3,
max=5,
step=0.01,
display_mode=IO.NumberDisplay.number,
tooltip="Higher value makes the image follow the prompt more closely.",
),
IO.Int.Input(
"steps",
default=50,
min=20,
max=50,
step=1,
display_mode=IO.NumberDisplay.number,
),
IO.DynamicCombo.Input(
"moderation",
options=[
IO.DynamicCombo.Option(
"true",
[
IO.Boolean.Input(
"prompt_content_moderation", default=False
),
IO.Boolean.Input(
"visual_input_moderation", default=False
),
IO.Boolean.Input(
"visual_output_moderation", default=True
),
],
),
IO.DynamicCombo.Option("false", []),
],
tooltip="Moderation settings",
),
IO.Mask.Input(
"mask",
tooltip="If omitted, the edit applies to the entire image.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(display_name="structured_prompt"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.04}""",
),
)
@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
prompt: str,
negative_prompt: str,
structured_prompt: str,
seed: int,
guidance_scale: float,
steps: int,
moderation: InputModerationSettings,
mask: Input.Image | None = None,
) -> IO.NodeOutput:
if not prompt and not structured_prompt:
raise ValueError(
"One of prompt or structured_prompt is required to be non-empty."
)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
mask_url = None
if mask is not None:
mask_url = (
await upload_images_to_comfyapi(
cls,
convert_mask_to_image(mask),
max_images=1,
mime_type="image/png",
wait_label="Uploading mask",
)
)[0]
response = await sync_op(
cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest(
instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None,
images=await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
wait_label="Uploading image",
),
mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale,
seed=seed,
model_version=model,
steps_num=steps,
prompt_content_moderation=moderation.get(
"prompt_content_moderation", False
),
visual_input_content_moderation=moderation.get(
"visual_input_moderation", False
),
visual_output_content_moderation=moderation.get(
"visual_output_moderation", False
),
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaResponse,
)
return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url),
response.result.structured_prompt,
)
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BriaImageEditNode,
]
async def comfy_entrypoint() -> BriaExtension:
return BriaExtension()

View File

@@ -5,7 +5,7 @@ import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_api import (
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
VIDEO_TASKS_EXECUTION_TIME,

View File

@@ -14,7 +14,7 @@ from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
from comfy_api_nodes.apis.gemini import (
GeminiContent,
GeminiFileData,
GeminiGenerateContentRequest,

View File

@@ -4,7 +4,7 @@ from comfy_api.latest import IO, ComfyExtension
from PIL import Image
import numpy as np
import torch
from comfy_api_nodes.apis import (
from comfy_api_nodes.apis.ideogram import (
IdeogramGenerateRequest,
IdeogramGenerateResponse,
ImageRequest,

View File

@@ -49,7 +49,7 @@ from comfy_api_nodes.apis import (
KlingCharacterEffectModelName,
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
from comfy_api_nodes.apis.kling import (
ImageToVideoWithAudioRequest,
MotionControlRequest,
OmniImageParamImage,

View File

@@ -4,7 +4,7 @@ import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.luma_api import (
from comfy_api_nodes.apis.luma import (
LumaAspectRatio,
LumaCharacterRef,
LumaConceptChain,

View File

@@ -4,7 +4,7 @@ import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.minimax_api import (
from comfy_api_nodes.apis.minimax import (
MinimaxFileRetrieveResponse,
MiniMaxModel,
MinimaxTaskResultResponse,

View File

@@ -3,7 +3,7 @@ import logging
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
from comfy_api_nodes.apis.moonvalley import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyTextToVideoRequest,

View File

@@ -10,24 +10,18 @@ from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
CreateModelResponseProperties,
Detail,
InputContent,
from comfy_api_nodes.apis.openai import (
InputFileContent,
InputImageContent,
InputMessage,
InputMessageContentList,
InputTextContent,
Item,
ModelResponseProperties,
OpenAICreateResponse,
OpenAIResponse,
OutputContent,
)
from comfy_api_nodes.apis.openai_api import (
OpenAIImageEditRequest,
OpenAIImageGenerationRequest,
OpenAIImageGenerationResponse,
OpenAIResponse,
OutputContent,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@@ -266,7 +260,7 @@ class OpenAIDalle3(IO.ComfyNode):
"seed",
default=0,
min=0,
max=2 ** 31 - 1,
max=2**31 - 1,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
@@ -370,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1",
display_name="OpenAI GPT Image 1.5",
category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
inputs=[
IO.String.Input(
"prompt",
@@ -384,7 +378,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
"seed",
default=0,
min=0,
max=2 ** 31 - 1,
max=2**31 - 1,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
@@ -435,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
IO.Combo.Input(
"model",
options=["gpt-image-1", "gpt-image-1.5"],
default="gpt-image-1.5",
optional=True,
),
],
@@ -500,8 +495,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
files = []
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i: i + 1]
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
single_image = image[i : i + 1]
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
@@ -523,7 +518,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
@@ -696,29 +691,23 @@ class OpenAIChatNode(IO.ComfyNode):
)
@classmethod
def get_message_content_from_response(
cls, response: OpenAIResponse
) -> list[OutputContent]:
def get_message_content_from_response(cls, response: OpenAIResponse) -> list[OutputContent]:
"""Extract message content from the API response."""
for output in response.output:
if output.root.type == "message":
return output.root.content
if output.type == "message":
return output.content
raise TypeError("No output message found in response")
@classmethod
def get_text_from_message_content(
cls, message_content: list[OutputContent]
) -> str:
def get_text_from_message_content(cls, message_content: list[OutputContent]) -> str:
"""Extract text content from message content."""
for content_item in message_content:
if content_item.root.type == "output_text":
return str(content_item.root.text)
if content_item.type == "output_text":
return str(content_item.text)
return "No text output found in response"
@classmethod
def tensor_to_input_image_content(
cls, image: torch.Tensor, detail_level: Detail = "auto"
) -> InputImageContent:
def tensor_to_input_image_content(cls, image: torch.Tensor, detail_level: str = "auto") -> InputImageContent:
"""Convert a tensor to an input image content object."""
return InputImageContent(
detail=detail_level,
@@ -732,9 +721,9 @@ class OpenAIChatNode(IO.ComfyNode):
prompt: str,
image: torch.Tensor | None = None,
files: list[InputFileContent] | None = None,
) -> InputMessageContentList:
) -> list[InputTextContent | InputImageContent | InputFileContent]:
"""Create a list of input message contents from prompt and optional image."""
content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
content_list: list[InputTextContent | InputImageContent | InputFileContent] = [
InputTextContent(text=prompt, type="input_text"),
]
if image is not None:
@@ -746,13 +735,9 @@ class OpenAIChatNode(IO.ComfyNode):
type="input_image",
)
)
if files is not None:
content_list.extend(files)
return InputMessageContentList(
root=content_list,
)
return content_list
@classmethod
async def execute(
@@ -762,7 +747,7 @@ class OpenAIChatNode(IO.ComfyNode):
model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
images: torch.Tensor | None = None,
files: list[InputFileContent] | None = None,
advanced_options: CreateModelResponseProperties | None = None,
advanced_options: ModelResponseProperties | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@@ -773,36 +758,28 @@ class OpenAIChatNode(IO.ComfyNode):
response_model=OpenAIResponse,
data=OpenAICreateResponse(
input=[
Item(
root=InputMessage(
content=cls.create_input_message_contents(
prompt, images, files
),
role="user",
)
InputMessage(
content=cls.create_input_message_contents(prompt, images, files),
role="user",
),
],
store=True,
stream=False,
model=model,
previous_response_id=None,
**(
advanced_options.model_dump(exclude_none=True)
if advanced_options
else {}
),
**(advanced_options.model_dump(exclude_none=True) if advanced_options else {}),
),
)
response_id = create_response.id
# Get result output
result_response = await poll_op(
cls,
ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"),
response_model=OpenAIResponse,
status_extractor=lambda response: response.status,
completed_statuses=["incomplete", "completed"]
)
cls,
ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"),
response_model=OpenAIResponse,
status_extractor=lambda response: response.status,
completed_statuses=["incomplete", "completed"],
)
return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
@@ -923,7 +900,7 @@ class OpenAIChatConfig(IO.ComfyNode):
remove depending on model choice.
"""
return IO.NodeOutput(
CreateModelResponseProperties(
ModelResponseProperties(
instructions=instructions,
truncation=truncation,
max_output_tokens=max_output_tokens,

View File

@@ -1,7 +1,7 @@
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import (
from comfy_api_nodes.apis.pixverse import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
PixverseTransitionVideoRequest,

View File

@@ -8,7 +8,7 @@ from typing_extensions import override
from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.recraft_api import (
from comfy_api_nodes.apis.recraft import (
RecraftColor,
RecraftColorChain,
RecraftControls,

View File

@@ -14,7 +14,7 @@ from typing import Optional
from io import BytesIO
from typing_extensions import override
from PIL import Image
from comfy_api_nodes.apis.rodin_api import (
from comfy_api_nodes.apis.rodin import (
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
Rodin3DCheckStatusRequest,

View File

@@ -16,7 +16,7 @@ from enum import Enum
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
from comfy_api_nodes.apis.runway import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
RunwayTaskStatusResponse as TaskStatusResponse,

View File

@@ -3,7 +3,7 @@ from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, Input, IO
from comfy_api_nodes.apis.stability_api import (
from comfy_api_nodes.apis.stability import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
StabilityAsyncResponse,

View File

@@ -5,7 +5,24 @@ import aiohttp
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import topaz_api
from comfy_api_nodes.apis.topaz import (
CreateVideoRequest,
CreateVideoRequestSource,
CreateVideoResponse,
ImageAsyncTaskResponse,
ImageDownloadResponse,
ImageEnhanceRequest,
ImageStatusResponse,
OutputInformationVideo,
Resolution,
VideoAcceptResponse,
VideoCompleteUploadRequest,
VideoCompleteUploadRequestPart,
VideoCompleteUploadResponse,
VideoEnhancementFilter,
VideoFrameInterpolationFilter,
VideoStatusResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@@ -153,13 +170,13 @@ class TopazImageEnhance(IO.ComfyNode):
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
download_url = await upload_images_to_comfyapi(
cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096
cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096
)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
response_model=topaz_api.ImageAsyncTaskResponse,
data=topaz_api.ImageEnhanceRequest(
response_model=ImageAsyncTaskResponse,
data=ImageEnhanceRequest(
model=model,
prompt=prompt,
subject_detection=subject_detection,
@@ -181,7 +198,7 @@ class TopazImageEnhance(IO.ComfyNode):
await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"),
response_model=topaz_api.ImageStatusResponse,
response_model=ImageStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: x.credits * 0.08,
@@ -193,7 +210,7 @@ class TopazImageEnhance(IO.ComfyNode):
results = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"),
response_model=topaz_api.ImageDownloadResponse,
response_model=ImageDownloadResponse,
monitor_progress=False,
)
return IO.NodeOutput(await download_url_to_image_tensor(results.download_url))
@@ -331,7 +348,7 @@ class TopazVideoEnhance(IO.ComfyNode):
if target_height % 2 != 0:
target_height += 1
filters.append(
topaz_api.VideoEnhancementFilter(
VideoEnhancementFilter(
model=UPSCALER_MODELS_MAP[upscaler_model],
creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
@@ -340,7 +357,7 @@ class TopazVideoEnhance(IO.ComfyNode):
if interpolation_enabled:
target_frame_rate = interpolation_frame_rate
filters.append(
topaz_api.VideoFrameInterpolationFilter(
VideoFrameInterpolationFilter(
model=interpolation_model,
slowmo=interpolation_slowmo,
fps=interpolation_frame_rate,
@@ -351,19 +368,19 @@ class TopazVideoEnhance(IO.ComfyNode):
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=topaz_api.CreateVideoResponse,
data=topaz_api.CreateVideoRequest(
source=topaz_api.CreateCreateVideoRequestSource(
response_model=CreateVideoResponse,
data=CreateVideoRequest(
source=CreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=topaz_api.Resolution(width=src_width, height=src_height),
resolution=Resolution(width=src_width, height=src_height),
),
filters=filters,
output=topaz_api.OutputInformationVideo(
resolution=topaz_api.Resolution(width=target_width, height=target_height),
output=OutputInformationVideo(
resolution=Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
@@ -379,7 +396,7 @@ class TopazVideoEnhance(IO.ComfyNode):
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=topaz_api.VideoAcceptResponse,
response_model=VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
@@ -402,10 +419,10 @@ class TopazVideoEnhance(IO.ComfyNode):
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=topaz_api.VideoCompleteUploadResponse,
data=topaz_api.VideoCompleteUploadRequest(
response_model=VideoCompleteUploadResponse,
data=VideoCompleteUploadRequest(
uploadResults=[
topaz_api.VideoCompleteUploadRequestPart(
VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
@@ -417,7 +434,7 @@ class TopazVideoEnhance(IO.ComfyNode):
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=topaz_api.VideoStatusResponse,
response_model=VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),

View File

@@ -5,7 +5,7 @@ import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.tripo_api import (
from comfy_api_nodes.apis.tripo import (
TripoAnimateRetargetRequest,
TripoAnimateRigRequest,
TripoConvertModelRequest,

View File

@@ -4,7 +4,7 @@ from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis.veo_api import (
from comfy_api_nodes.apis.veo import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
VeoGenVidRequest,

View File

@@ -703,7 +703,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
"subjects",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_images"),
names=["subject1", "subject2", "subject3"],
names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"],
min=1,
),
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
@@ -738,7 +738,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
control_after_generate=True,
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
IO.Combo.Input("resolution", options=["720p"]),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],

View File

@@ -0,0 +1,178 @@
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.wavespeed import (
FlashVSRRequest,
TaskCreatedResponse,
TaskResultResponse,
SeedVR2ImageRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_video_duration,
upload_images_to_comfyapi,
get_number_of_images,
download_url_to_image_tensor,
)
class WavespeedFlashVSRNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="WavespeedFlashVSRNode",
display_name="FlashVSR Video Upscale",
category="api node/video/WaveSpeed",
description="Fast, high-quality video upscaler that "
"boosts resolution and restores clarity for low-resolution or blurry footage.",
inputs=[
IO.Video.Input("video"),
IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
expr="""
(
$price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032};
{
"type":"usd",
"usd": $lookup($price_for_1sec, widgets.target_resolution),
"format":{"suffix": "/second", "approximate": true}
}
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
target_resolution: str,
) -> IO.NodeOutput:
validate_container_format_is_mp4(video)
validate_video_duration(video, min_duration=5, max_duration=60 * 10)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
response_model=TaskCreatedResponse,
data=FlashVSRRequest(
target_resolution=target_resolution.lower(),
video=await upload_video_to_comfyapi(cls, video),
duration=video.get_duration(),
),
)
if initial_res.code != 200:
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
response_model=TaskResultResponse,
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
poll_interval=10.0,
max_poll_attempts=480,
)
if final_response.code != 200:
raise ValueError(
f"Task processing failed with code={final_response.code} and message={final_response.message}"
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0]))
class WavespeedImageUpscaleNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="WavespeedImageUpscaleNode",
display_name="WaveSpeed Image Upscale",
category="api node/image/WaveSpeed",
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
inputs=[
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
IO.Image.Input("image"),
IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$prices := {"seedvr2": 0.01, "ultimate": 0.06};
{"type":"usd", "usd": $lookup($prices, widgets.model)}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
target_resolution: str,
) -> IO.NodeOutput:
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
if model == "SeedVR2":
model_path = "seedvr2/image"
else:
model_path = "ultimate-image-upscaler"
initial_res = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
response_model=TaskCreatedResponse,
data=SeedVR2ImageRequest(
target_resolution=target_resolution.lower(),
image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
),
)
if initial_res.code != 200:
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
response_model=TaskResultResponse,
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
poll_interval=10.0,
max_poll_attempts=480,
)
if final_response.code != 200:
raise ValueError(
f"Task processing failed with code={final_response.code} and message={final_response.message}"
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
class WavespeedExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
WavespeedFlashVSRNode,
WavespeedImageUpscaleNode,
]
async def comfy_entrypoint() -> WavespeedExtension:
return WavespeedExtension()
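
Both WaveSpeed nodes above follow the same create-then-poll flow: submit a task, check the creation response code, then poll the result endpoint until it reports completion or failure. The sketch below is a standalone illustration of that flow using plain requests calls; the URLs, JSON field names, and the "completed" status value are placeholders rather than the real proxy API, which the nodes reach through ComfyUI's sync_op/poll_op helpers.

import time
import requests

def run_wavespeed_style_task(base_url: str, payload: dict, poll_interval: float = 10.0, max_attempts: int = 480) -> list[str]:
    # 1) Create the task; a non-200 body code means creation failed.
    created = requests.post(f"{base_url}/task", json=payload, timeout=60).json()
    if created.get("code") != 200:
        raise ValueError(f"Task creation failed with message={created.get('message')}")
    task_id = created["data"]["id"]
    # 2) Poll the result endpoint until the task finishes or the budget runs out.
    for _ in range(max_attempts):
        result = requests.get(f"{base_url}/predictions/{task_id}/result", timeout=60).json()
        status = "failed" if result.get("data") is None else result["data"]["status"]
        if status == "completed":
            return result["data"]["outputs"]
        if status == "failed":
            raise ValueError(f"Task processing failed with message={result.get('message')}")
        time.sleep(poll_interval)
    raise TimeoutError("Task did not finish within the polling budget")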

View File

@@ -1,10 +0,0 @@
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
# This is used for development purposes to generate stubs for unreleased API endpoints.
apis:
filter:
root: openapi.yaml
decorators:
filter-in:
property: tags
value: ['API Nodes']
matchStrategy: all

View File

@@ -1,10 +0,0 @@
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
apis:
filter:
root: openapi.yaml
decorators:
filter-in:
property: tags
value: ['API Nodes', 'Released']
matchStrategy: all

View File

@@ -11,6 +11,7 @@ from .conversions import (
audio_input_to_mp3,
audio_to_base64_string,
bytesio_to_image_tensor,
convert_mask_to_image,
downscale_image_tensor,
image_tensor_pair_to_batch,
pil_to_bytesio,
@@ -72,6 +73,7 @@ __all__ = [
"audio_input_to_mp3",
"audio_to_base64_string",
"bytesio_to_image_tensor",
"convert_mask_to_image",
"downscale_image_tensor",
"image_tensor_pair_to_batch",
"pil_to_bytesio",

View File

@@ -451,6 +451,12 @@ def resize_mask_to_image(
return mask
def convert_mask_to_image(mask: Input.Image) -> torch.Tensor:
"""Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image."""
mask = mask.unsqueeze(-1)
return torch.cat([mask] * 3, dim=-1)
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
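
The convert_mask_to_image helper added above is small enough to sanity-check in isolation: a [B, H, W] mask gains a trailing channel dimension and is repeated to three channels so downstream code treats it as an image. A minimal standalone check, re-declared here with a plain tensor annotation and a dummy mask so the snippet runs on its own:

import torch

def convert_mask_to_image(mask: torch.Tensor) -> torch.Tensor:
    mask = mask.unsqueeze(-1)              # [B, H, W] -> [B, H, W, 1]
    return torch.cat([mask] * 3, dim=-1)   # -> [B, H, W, 3]

mask = torch.rand(1, 64, 64)
image = convert_mask_to_image(mask)
assert image.shape == (1, 64, 64, 3)
assert torch.equal(image[..., 0], mask)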

View File

@@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if there isn't a cache diff for current conds, we cannot skip this step
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.skip_current_step and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)
@@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
@@ -240,6 +242,9 @@ class EasyCacheHolder:
return to_return.clone()
return to_return
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
self.total_steps_skipped += 1
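
The guard added above only allows a step to be skipped when every cond uuid in the current batch already has a cached diff to reapply. A simplified stand-in class (not the real EasyCacheHolder) showing the effect of the combined condition:

from uuid import UUID, uuid4

class EasyCacheHolderSketch:
    def __init__(self):
        self.uuid_cache_diffs: dict[UUID, object] = {}
        self.skip_current_step = False

    def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
        return all(u in self.uuid_cache_diffs for u in uuids)

    def should_skip(self, uuids: list[UUID]) -> bool:
        # Mirrors the new condition: marked to skip AND diffs cached for all conds.
        return self.skip_current_step and self.can_apply_cache_diff(uuids)

cache = EasyCacheHolderSketch()
a, b = uuid4(), uuid4()
cache.skip_current_step = True
assert not cache.should_skip([a, b])                  # no cached diffs yet, cannot skip
cache.uuid_cache_diffs = {a: object(), b: object()}
assert cache.should_skip([a, b])                      # all conds cached, safe to skip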

View File

@@ -24,7 +24,7 @@ class Load3D(IO.ComfyNode):
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
]
return IO.Schema(
node_id="Load3D",

View File

@@ -7,6 +7,7 @@ import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
import comfy.ldm.lumina.controlnet
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
class BlockWiseControlBlock(torch.nn.Module):
@@ -257,6 +258,14 @@ class ModelPatchLoader:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
elif "audio_proj.proj1.weight" in sd:
model = MultiTalkModelPatch(
audio_window=5, context_tokens=32, vae_scale=4,
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
out_dim=sd["audio_proj.norm.weight"].shape[0],
device=comfy.model_management.unet_offload_device(),
operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -524,6 +533,38 @@ class USOStyleReference:
return (model_patched,)
class MultiTalkModelPatch(torch.nn.Module):
def __init__(
self,
audio_window: int = 5,
intermediate_dim: int = 512,
in_dim: int = 5120,
out_dim: int = 768,
context_tokens: int = 32,
vae_scale: int = 4,
num_layers: int = 40,
device=None, dtype=None, operations=None
):
super().__init__()
self.audio_proj = MultiTalkAudioProjModel(
seq_len=audio_window,
seq_len_vf=audio_window+vae_scale-1,
intermediate_dim=intermediate_dim,
out_dim=out_dim,
context_tokens=context_tokens,
device=device,
dtype=dtype,
operations=operations
)
self.blocks = torch.nn.ModuleList(
[
WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
]
)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,

View File

@@ -550,6 +550,7 @@ class BatchImagesNode(io.ComfyNode):
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[
io.Autogrow.Input("images", template=autogrow_template)
],

View File

@@ -16,6 +16,7 @@ class PreviewAny():
OUTPUT_NODE = True
CATEGORY = "utils"
SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"]
def main(self, source=None):
value = 'None'

View File

@@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode):
node_id="StringConcatenate",
display_name="Concatenate",
category="utils/string",
search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
inputs=[
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),

View File

@@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
node_id="ImageUpscaleWithModel",
display_name="Upscale Image (using Model)",
category="image/upscaling",
search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),

View File

@@ -159,6 +159,29 @@ class GetVideoComponents(io.ComfyNode):
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoSlice",
display_name="Video Slice",
category="image/video",
description="Extract a range of frames from a video.",
inputs=[
io.Video.Input("video", tooltip="The video to slice."),
io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
],
outputs=[
io.Video.Output(tooltip="The sliced video."),
],
)
@classmethod
def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
return io.NodeOutput(video.sliced(start_frame, frame_count))
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -206,6 +229,7 @@ class VideoExtension(ComfyExtension):
SaveVideo,
CreateVideo,
GetVideoComponents,
VideoSlice,
LoadVideo,
]
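
VideoSlice itself just calls video.sliced(start_frame, frame_count); the interesting part is that slicing is lazy and composes. The toy classes below are stand-ins (not the real VideoInput types) showing how appended slice operations are applied in order only when frames are materialized:

from dataclasses import dataclass, field
import torch

@dataclass(frozen=True)
class SliceOpSketch:
    start_frame: int
    frame_count: int

    def apply(self, frames: torch.Tensor) -> torch.Tensor:
        return frames[self.start_frame:self.start_frame + self.frame_count]

@dataclass
class LazyVideoSketch:
    frames: torch.Tensor
    operations: list = field(default_factory=list)

    def sliced(self, start_frame: int, frame_count: int) -> "LazyVideoSketch":
        # Return a copy with the operation appended; the original list is untouched.
        return LazyVideoSketch(self.frames, self.operations + [SliceOpSketch(start_frame, frame_count)])

    def get_frames(self) -> torch.Tensor:
        out = self.frames
        for op in self.operations:
            out = op.apply(out)
        return out

video = LazyVideoSketch(torch.rand(10, 4, 4, 3))
sliced = video.sliced(2, 6).sliced(1, 3)        # second slice is relative to the first
assert torch.equal(sliced.get_frames(), video.frames[3:6])
assert len(video.operations) == 0 and len(sliced.operations) == 2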

View File

@@ -8,9 +8,10 @@ import comfy.latent_formats
import comfy.clip_vision
import json
import numpy as np
from typing import Tuple
from typing import Tuple, TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging
class WanImageToVideo(io.ComfyNode):
@classmethod
@@ -1288,6 +1289,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
return io.NodeOutput(out_latent)
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
class WanInfiniteTalkToVideo(io.ComfyNode):
class DCValues(TypedDict):
mode: str
audio_encoder_output_2: io.AudioEncoderOutput.Type
mask: io.Mask.Type
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanInfiniteTalkToVideo",
category="conditioning/video_models",
inputs=[
io.DynamicCombo.Input("mode", options=[
io.DynamicCombo.Option("single_speaker", []),
io.DynamicCombo.Option("two_speakers", [
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
]),
]),
io.Model.Input("model"),
io.ModelPatch.Input("model_patch"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
io.Image.Input("previous_frames", optional=True),
],
outputs=[
io.Model.Output(display_name="model"),
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_image"),
],
)
@classmethod
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
raise ValueError("Not enough previous frames provided.")
if mode["mode"] == "two_speakers":
audio_encoder_output_2 = mode["audio_encoder_output_2"]
mask_1 = mode["mask_1"]
mask_2 = mode["mask_2"]
if audio_encoder_output_2 is not None:
if mask_1 is None or mask_2 is None:
raise ValueError("Masks must be provided if two audio encoder outputs are used.")
ref_masks = None
if mask_1 is not None and mask_2 is not None:
if audio_encoder_output_2 is None:
raise ValueError("Second audio encoder output must be provided if two masks are used.")
ref_masks = torch.cat([mask_1, mask_2])
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
model_patched = model.clone()
encoded_audio_list = []
seq_lengths = []
for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
if audio_encoder_output is None:
continue
all_layers = audio_encoder_output["encoded_audio_all_layers"]
encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
encoded_audio_list.append(encoded_audio)
seq_lengths.append(encoded_audio.shape[0])
# Pad / combine depending on multi_audio_type
multi_audio_type = "add"
if len(encoded_audio_list) > 1:
if multi_audio_type == "para":
max_len = max(seq_lengths)
padded = []
for emb in encoded_audio_list:
if emb.shape[0] < max_len:
pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
emb = torch.cat([emb, pad], dim=0)
padded.append(emb)
encoded_audio_list = padded
elif multi_audio_type == "add":
total_len = sum(seq_lengths)
full_list = []
offset = 0
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
full[offset:offset+seq_len] = emb
full_list.append(full)
offset += seq_len
encoded_audio_list = full_list
token_ref_target_masks = None
if ref_masks is not None:
token_ref_target_masks = torch.nn.functional.interpolate(
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
# when extending from previous frames
if previous_frames is not None:
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
frame_offset = previous_frames.shape[0] - motion_frame_count
audio_start = frame_offset
audio_end = audio_start + length
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
trim_image = motion_frame_count
else:
audio_start = trim_image = 0
audio_end = length
motion_frames_latent = concat_latent_image[:, :, :1]
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
# add outer sample wrapper
model_patched.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
"infinite_talk_outer_sample",
InfiniteTalkOuterSampleWrapper(
motion_frames_latent,
model_patch,
is_extend=previous_frames is not None,
))
# add cross-attention patch
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
if token_ref_target_masks is not None:
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -1307,6 +1473,7 @@ class WanExtension(ComfyExtension):
WanHuMoImageToVideo,
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
]
async def comfy_entrypoint() -> WanExtension:

View File

@@ -0,0 +1,88 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import math
import comfy.utils
class TextEncodeZImageOmni(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeZImageOmni",
category="advanced/conditioning",
is_experimental=True,
inputs=[
io.Clip.Input("clip"),
io.ClipVision.Input("image_encoder", optional=True),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Boolean.Input("auto_resize_images", default=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
ref_latents = []
images = list(filter(lambda a: a is not None, [image1, image2, image3]))
prompt_list = []
template = None
if len(images) > 0:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1)
prompt_list += ["<|vision_end|><|im_end|>"]
template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"
encoded_images = []
for i, image in enumerate(images):
if image_encoder is not None:
encoded_images.append(image_encoder.encode_image(image))
if vae is not None:
if auto_resize_images:
samples = image.movedim(-1, 1)
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by / 8.0) * 8
height = round(samples.shape[2] * scale_by / 8.0) * 8
image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
ref_latents.append(vae.encode(image))
tokens = clip.tokenize(prompt, llama_template=template)
conditioning = clip.encode_from_tokens_scheduled(tokens)
extra_text_embeds = []
for p in prompt_list:
tokens = clip.tokenize(p, llama_template="{}")
text_embeds = clip.encode_from_tokens_scheduled(tokens)
extra_text_embeds.append(text_embeds[0][0])
if len(ref_latents) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
if len(encoded_images) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True)
if len(extra_text_embeds) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True)
return io.NodeOutput(conditioning)
class ZImageExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeZImageOmni,
]
async def comfy_entrypoint() -> ZImageExtension:
return ZImageExtension()
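
When auto_resize_images is enabled, TextEncodeZImageOmni rescales each reference image so its area is roughly 1024x1024 pixels while preserving aspect ratio, rounding each side to a multiple of 8 before VAE encoding. The arithmetic in isolation:

import math

def target_size(width: int, height: int, total: int = 1024 * 1024) -> tuple[int, int]:
    scale_by = math.sqrt(total / (width * height))
    return round(width * scale_by / 8.0) * 8, round(height * scale_by / 8.0) * 8

assert target_size(2048, 1536) == (1184, 888)   # ~1.05 MP, both sides multiples of 8
assert target_size(512, 512) == (1024, 1024)    # small inputs are scaled up, too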

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.9.2"
__version__ = "0.10.0"

View File

@@ -11,7 +11,7 @@ import logging
default_preview_method = args.preview_method
MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
def preview_to_image(latent_image, do_scale=True):
if do_scale:

View File

@@ -5,6 +5,7 @@ import torch
import os
import sys
import json
import glob
import hashlib
import inspect
import traceback
@@ -69,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC):
CATEGORY = "conditioning"
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"]
def encode(self, clip, text):
if clip is None:
@@ -85,6 +87,7 @@ class ConditioningCombine:
FUNCTION = "combine"
CATEGORY = "conditioning"
SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"]
def combine(self, conditioning_1, conditioning_2):
return (conditioning_1 + conditioning_2, )
@@ -293,6 +296,7 @@ class VAEDecode:
CATEGORY = "latent"
DESCRIPTION = "Decodes latent images back into pixel space images."
SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"]
def decode(self, vae, samples):
latent = samples["samples"]
@@ -345,6 +349,7 @@ class VAEEncode:
FUNCTION = "encode"
CATEGORY = "latent"
SEARCH_ALIASES = ["encode", "encode image", "image to latent"]
def encode(self, vae, pixels):
t = vae.encode(pixels)
@@ -580,6 +585,7 @@ class CheckpointLoaderSimple:
CATEGORY = "loaders"
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"]
def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
@@ -666,6 +672,7 @@ class LoraLoader:
CATEGORY = "loaders"
DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"]
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
@@ -700,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@staticmethod
def vae_list(s):
@@ -813,6 +820,7 @@ class ControlNetLoader:
FUNCTION = "load_controlnet"
CATEGORY = "loaders"
SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"]
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
@@ -889,6 +897,7 @@ class ControlNetApplyAdvanced:
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning/controlnet"
SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"]
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
if strength == 0:
@@ -1199,6 +1208,7 @@ class EmptyLatentImage:
CATEGORY = "latent"
DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
@@ -1539,6 +1549,7 @@ class KSampler:
CATEGORY = "sampling"
DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"]
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
@@ -1603,6 +1614,7 @@ class SaveImage:
CATEGORY = "image"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"]
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
@@ -1639,6 +1651,8 @@ class PreviewImage(SaveImage):
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"]
@classmethod
def INPUT_TYPES(s):
return {"required":
@@ -1657,6 +1671,7 @@ class LoadImage:
}
CATEGORY = "image"
SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"]
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
@@ -1809,6 +1824,7 @@ class ImageScale:
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"]
def upscale(self, image, upscale_method, width, height, crop):
if width == 0 and height == 0:
@@ -2372,6 +2388,7 @@ async def init_builtin_extra_nodes():
"nodes_kandinsky5.py",
"nodes_wanmove.py",
"nodes_image_compare.py",
"nodes_zimage.py",
]
import_failed = []
@@ -2384,38 +2401,12 @@ async def init_builtin_extra_nodes():
async def init_builtin_api_nodes():
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
api_nodes_files = [
"nodes_ideogram.py",
"nodes_openai.py",
"nodes_minimax.py",
"nodes_veo2.py",
"nodes_kling.py",
"nodes_bfl.py",
"nodes_bytedance.py",
"nodes_ltxv.py",
"nodes_luma.py",
"nodes_recraft.py",
"nodes_pixverse.py",
"nodes_stability.py",
"nodes_runway.py",
"nodes_sora.py",
"nodes_topaz.py",
"nodes_tripo.py",
"nodes_meshy.py",
"nodes_moonvalley.py",
"nodes_rodin.py",
"nodes_gemini.py",
"nodes_vidu.py",
"nodes_wan.py",
]
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
return api_nodes_files
api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py")))
import_failed = []
for node_file in api_nodes_files:
if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
import_failed.append(node_file)
if not await load_custom_node(node_file, module_parent="comfy_api_nodes"):
import_failed.append(os.path.basename(node_file))
return import_failed

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.9.2"
version = "0.10.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.36.14
comfyui-workflow-templates==0.8.11
comfyui-frontend-package==1.37.11
comfyui-workflow-templates==0.8.15
comfyui-embedded-docs==0.4.0
torch
torchsde
@@ -21,7 +21,7 @@ psutil
alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.6
comfy-kitchen>=0.2.7
#non essential dependencies:
kornia>=0.7.1

View File

@@ -682,6 +682,8 @@ class PromptServer():
if hasattr(obj_class, 'API_NODE'):
info['api_node'] = obj_class.API_NODE
info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', [])
return info
@routes.get("/object_info")

View File

@@ -1,297 +0,0 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.mapper_utils import model_field_to_node_input
def test_model_field_to_float_input():
"""Tests mapping a float field with constraints."""
class ModelWithFloatField(BaseModel):
cfg_scale: Optional[float] = Field(
default=0.5,
description="Flexibility in video generation",
ge=0.0,
le=1.0,
multiple_of=0.001,
)
expected_output = (
IO.FLOAT,
{
"default": 0.5,
"tooltip": "Flexibility in video generation",
"min": 0.0,
"max": 1.0,
"step": 0.001,
},
)
actual_output = model_field_to_node_input(
IO.FLOAT, ModelWithFloatField, "cfg_scale"
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_float_input_no_constraints():
"""Tests mapping a float field with no constraints."""
class ModelWithFloatField(BaseModel):
cfg_scale: Optional[float] = Field(default=0.5)
expected_output = (
IO.FLOAT,
{
"default": 0.5,
},
)
actual_output = model_field_to_node_input(
IO.FLOAT, ModelWithFloatField, "cfg_scale"
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_int_input():
"""Tests mapping an int field with constraints."""
class ModelWithIntField(BaseModel):
num_frames: Optional[int] = Field(
default=10,
description="Number of frames to generate",
ge=1,
le=100,
multiple_of=1,
)
expected_output = (
IO.INT,
{
"default": 10,
"tooltip": "Number of frames to generate",
"min": 1,
"max": 100,
"step": 1,
},
)
actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames")
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_string_input():
"""Tests mapping a string field."""
class ModelWithStringField(BaseModel):
prompt: Optional[str] = Field(
default="A beautiful sunset over a calm ocean",
description="A prompt for the video generation",
)
expected_output = (
IO.STRING,
{
"default": "A beautiful sunset over a calm ocean",
"tooltip": "A prompt for the video generation",
},
)
actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt")
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_string_input_multiline():
"""Tests mapping a string field."""
class ModelWithStringField(BaseModel):
prompt: Optional[str] = Field(
default="A beautiful sunset over a calm ocean",
description="A prompt for the video generation",
)
expected_output = (
IO.STRING,
{
"default": "A beautiful sunset over a calm ocean",
"tooltip": "A prompt for the video generation",
"multiline": True,
},
)
actual_output = model_field_to_node_input(
IO.STRING, ModelWithStringField, "prompt", multiline=True
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_combo_input():
"""Tests mapping a combo field."""
class MockEnum(str, Enum):
option_1 = "option 1"
option_2 = "option 2"
option_3 = "option 3"
class ModelWithComboField(BaseModel):
model_name: Optional[MockEnum] = Field("option 1", description="Model Name")
expected_output = (
IO.COMBO,
{
"options": ["option 1", "option 2", "option 3"],
"default": "option 1",
"tooltip": "Model Name",
},
)
actual_output = model_field_to_node_input(
IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_combo_input_no_options():
"""Tests mapping a combo field with no options."""
class ModelWithComboField(BaseModel):
model_name: Optional[str] = Field(description="Model Name")
expected_output = (
IO.COMBO,
{
"tooltip": "Model Name",
},
)
actual_output = model_field_to_node_input(
IO.COMBO, ModelWithComboField, "model_name"
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_image_input():
"""Tests mapping an image field."""
class ModelWithImageField(BaseModel):
image: Optional[str] = Field(
default=None,
description="An image for the video generation",
)
expected_output = (
IO.IMAGE,
{
"default": None,
"tooltip": "An image for the video generation",
},
)
actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image")
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_node_input_no_description():
"""Tests mapping a field with no description."""
class ModelWithNoDescriptionField(BaseModel):
field: Optional[str] = Field(default="default value")
expected_output = (
IO.STRING,
{
"default": "default value",
},
)
actual_output = model_field_to_node_input(
IO.STRING, ModelWithNoDescriptionField, "field"
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_node_input_no_default():
"""Tests mapping a field with no default."""
class ModelWithNoDefaultField(BaseModel):
field: Optional[str] = Field(description="A field with no default")
expected_output = (
IO.STRING,
{
"tooltip": "A field with no default",
},
)
actual_output = model_field_to_node_input(
IO.STRING, ModelWithNoDefaultField, "field"
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_node_input_no_metadata():
"""Tests mapping a field with no metadata or properties defined on the schema."""
class ModelWithNoMetadataField(BaseModel):
field: Optional[str] = Field()
expected_output = (
IO.STRING,
{},
)
actual_output = model_field_to_node_input(
IO.STRING, ModelWithNoMetadataField, "field"
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]
def test_model_field_to_node_input_default_is_none():
"""
Tests mapping a field with a default of `None`.
I.e., the default field should be included as the schema explicitly sets it to `None`.
"""
class ModelWithNoneDefaultField(BaseModel):
field: Optional[str] = Field(
default=None, description="A field with a default of None"
)
expected_output = (
IO.STRING,
{
"default": None,
"tooltip": "A field with a default of None",
},
)
actual_output = model_field_to_node_input(
IO.STRING, ModelWithNoneDefaultField, "field"
)
assert actual_output[0] == expected_output[0]
assert actual_output[1] == expected_output[1]

View File

@@ -0,0 +1,150 @@
import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
VideoFromFile,
VideoFromComponents,
SliceOp,
)
from comfy_api.util.video_types import VideoComponents
def create_test_video(width=4, height=4, frames=10, fps=30):
"""Helper to create a temporary video file."""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def video_file_10_frames():
file_path = create_test_video(frames=10)
yield file_path
os.unlink(file_path)
@pytest.fixture
def video_components_10_frames():
images = torch.rand(10, 4, 4, 3)
return VideoComponents(images=images, frame_rate=Fraction(30))
class TestSliceOp:
def test_apply_slices_correctly(self, video_components_10_frames):
op = SliceOp(start_frame=2, frame_count=3)
result = op.apply(video_components_10_frames)
assert result.images.shape[0] == 3
assert torch.equal(result.images, video_components_10_frames.images[2:5])
def test_compute_frame_count(self):
op = SliceOp(start_frame=2, frame_count=5)
assert op.compute_frame_count(10) == 5
def test_compute_frame_count_clamps(self):
op = SliceOp(start_frame=8, frame_count=5)
assert op.compute_frame_count(10) == 2
class TestVideoSliced:
def test_sliced_returns_new_instance(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert video is not sliced
assert len(video._operations) == 0
assert len(sliced._operations) == 1
def test_get_components_applies_operations(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[2:5])
def test_get_frame_count(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert sliced.get_frame_count() == 3
def test_get_duration(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(0, 3)
assert sliced.get_duration() == pytest.approx(0.1)
def test_chained_slices_compose(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 6).sliced(1, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[3:6])
def test_operations_list_is_immutable(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced1 = video.sliced(0, 5)
sliced2 = sliced1.sliced(1, 2)
assert len(video._operations) == 0
assert len(sliced1._operations) == 1
assert len(sliced2._operations) == 2
def test_from_file(self, video_file_10_frames):
video = VideoFromFile(video_file_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert sliced.get_frame_count() == 3
def test_save_sliced_video(self, video_components_10_frames, tmp_path):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
output_path = str(tmp_path / "sliced_output.mp4")
sliced.save_to(output_path)
saved_video = VideoFromFile(output_path)
assert saved_video.get_frame_count() == 3
def test_materialization_clears_ops(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert len(sliced._operations) == 1
sliced.get_components()
assert len(sliced._operations) == 0
def test_second_get_components_uses_cache(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
first = sliced.get_components()
second = sliced.get_components()
assert first.images.shape == second.images.shape
assert torch.equal(first.images, second.images)
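
The clamping tests above pin down the slice semantics: a slice that runs past the end of the video is shortened rather than raising. The formula below reproduces those expected values; it is inferred from the tests, not copied from the implementation:

def compute_frame_count(start_frame: int, frame_count: int, total_frames: int) -> int:
    # Frames available from start_frame, capped at the requested count.
    return max(0, min(start_frame + frame_count, total_frames) - start_frame)

assert compute_frame_count(2, 5, 10) == 5   # fits entirely
assert compute_frame_count(8, 5, 10) == 2   # clamped at the end of the video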