Compare commits

..

96 Commits

Author SHA1 Message Date
Jedrzej Kosinski
295b49c165 Doing some experimentation 2025-09-02 22:19:12 -07:00
Jedrzej Kosinski
a40c5ae341 Support predict_ratio changing with timesteps 2025-09-02 15:23:28 -07:00
Jedrzej Kosinski
953b906f63 Implement Sortblock for single cond usage 2025-09-02 00:45:59 -07:00
Jedrzej Kosinski
d4a8752c8c some exploration of sortblock as more things from paper/source code need to be added 2025-09-01 09:39:40 -07:00
Jedrzej Kosinski
cf26d3d58e More progress on Sortblock 2025-08-31 20:26:49 -07:00
Jedrzej Kosinski
f655fcc5ce Progress on scaffolding for an EasyCache style implementation of Sortblock 2025-08-31 00:59:01 -07:00
Jedrzej Kosinski
e2491f44e8 Merge branch 'attention-select' into sortblock 2025-08-30 20:04:48 -07:00
Jedrzej Kosinski
66c4eb006b Remove AttentionOverrideTest node, that's something to cook up for later 2025-08-30 15:19:36 -07:00
Jedrzej Kosinski
dd0a5093f6 Satisfy ruff 2025-08-30 14:58:30 -07:00
Jedrzej Kosinski
c092b8a4ac Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case 2025-08-30 14:49:04 -07:00
Jedrzej Kosinski
eaa9433ff8 Remove attention logging code 2025-08-30 14:45:12 -07:00
comfyanonymous
4449e14769 ComfyUI version 0.3.56 2025-08-30 06:31:19 -04:00
Jedrzej Kosinski
720d0a88e6 Disable attention logs for now 2025-08-30 01:11:34 -07:00
Jedrzej Kosinski
d9bb4530b0 Merge branch 'master' into attention-select 2025-08-29 23:35:38 -07:00
Jedrzej Kosinski
cb959f9669 Add optimized to get_attention_function 2025-08-29 21:48:36 -07:00
comfyanonymous
885015eecf Lower ram usage on windows. (#9628) 2025-08-29 23:06:04 -04:00
Jedrzej Kosinski
d553073a1e Fixed WAN 2.1 VACE transformer_options passthrough 2025-08-29 13:20:43 -07:00
Jedrzej Kosinski
af288b9946 Fixed Wan2.1 Fun Camera transformer_options passthrough 2025-08-29 13:06:37 -07:00
comfyanonymous
a86aaa4301 ComfyUI v0.3.55 2025-08-29 06:03:41 -04:00
ComfyUI Wiki
2efb2cbc38 Update template to 0.1.70 (#9620) 2025-08-29 06:03:25 -04:00
Jedrzej Kosinski
1ae6fe14a7 Fix WanI2VCrossAttention so that it expects to receive transformer_options 2025-08-29 02:31:16 -07:00
comfyanonymous
15aa9222c4 Trim audio to video when saving video. (#9617) 2025-08-29 04:12:00 -04:00
Jedrzej Kosinski
2d13bf1c7a Made SVD work with optimized_attention_override 2025-08-28 22:45:45 -07:00
Jedrzej Kosinski
8be3edb606 Made Chroma work with optimized_attention_override 2025-08-28 22:45:31 -07:00
Jedrzej Kosinski
d644aba6bc Made Lumina work with optimized_attention_override 2025-08-28 22:00:44 -07:00
Jedrzej Kosinski
17090c56be Made AuraFlow work with optimized_attention_override 2025-08-28 21:46:56 -07:00
Jedrzej Kosinski
034d6c12e6 Made StableCascade work with optimized_attention_override 2025-08-28 21:42:08 -07:00
Jedrzej Kosinski
09c84b31a2 Made Omnigen 2 work with optimized_attention_override 2025-08-28 21:30:18 -07:00
Jedrzej Kosinski
8fe2dea297 Made CosmosVideo work with optimized_attention_override 2025-08-28 21:23:03 -07:00
Jedrzej Kosinski
4a44ed4a76 Make CosmosPredict2 work with optimized_attention_override 2025-08-28 21:18:34 -07:00
Jedrzej Kosinski
8b9b4bbb62 Made Hunyuan3D work with optimized_attention_override 2025-08-28 21:06:44 -07:00
Jedrzej Kosinski
27ebd312ae Made optimized_attention_override work with ACE Step 2025-08-28 21:03:28 -07:00
Jedrzej Kosinski
9461f30387 Made StableAudio work with optimized_attention_override 2025-08-28 20:56:56 -07:00
Jedrzej Kosinski
2cda45d1b4 Made LTX work with optimized_attention_override 2025-08-28 20:42:22 -07:00
Jedrzej Kosinski
61b5c5fc75 Made Mochi work with optimized_attention_override 2025-08-28 20:34:06 -07:00
Jedrzej Kosinski
ef894cdf08 Made HunyuanVideo work with optimized_attention_override 2025-08-28 20:26:53 -07:00
Jedrzej Kosinski
0ac5c6344f Made SD3 work with optimized_attention_override 2025-08-28 20:21:14 -07:00
Jedrzej Kosinski
1ddfb5bb14 Made wan patches_replace work with optimized_attention_override 2025-08-28 20:13:51 -07:00
Jedrzej Kosinski
4cafd58f71 Made hidream work with optimized_attention_override 2025-08-28 20:10:50 -07:00
Jedrzej Kosinski
f752715aac Make Qwen work with optimized_attention_override 2025-08-28 19:52:52 -07:00
comfyanonymous
c7bb3e2bce Support the 5B fun inpaint model. (#9614)
Use the WanFunInpaintToVideo node without the clip_vision_output.
2025-08-28 22:46:57 -04:00
Jedrzej Kosinski
48ed71caf8 Add logs to verify optimized_attention_override is passed all the way into attention function 2025-08-28 19:43:39 -07:00
Jedrzej Kosinski
a7d70e42a0 Make flux work with optimized_attention_override 2025-08-28 19:33:02 -07:00
comfyanonymous
e80a14ad50 Support wan2.2 5B fun control model. (#9611)
Use the Wan22FunControlToVideo node.
2025-08-28 22:13:07 -04:00
Jedrzej Kosinski
1f499f0794 Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only) 2025-08-28 18:54:22 -07:00
Jedrzej Kosinski
51a30c2ad7 Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention 2025-08-28 18:53:20 -07:00
comfyanonymous
d28b39d93d Add a LatentCut node to cut latents. (#9609) 2025-08-28 19:38:28 -04:00
comfyanonymous
1c184c29eb Fix issue with s2v node when extending past audio length. (#9608) 2025-08-28 18:34:01 -04:00
comfyanonymous
edde0b5043 WanSoundImageToVideoExtend node to manually extend s2v video. (#9606) 2025-08-28 17:59:48 -04:00
Jedrzej Kosinski
669b9ef8e6 Added **kwargs to all attention functions so transformer_options could potentially be passed through 2025-08-28 13:14:41 -07:00
comfyanonymous
0063610177 ComfyUI version 0.3.54 2025-08-28 10:44:57 -04:00
comfyanonymous
ce0052c087 Fix diffsynth controlnet regression. (#9597) 2025-08-28 10:37:42 -04:00
comfyanonymous
0eb821a7b6 ComfyUI 0.3.53 2025-08-27 23:09:06 -04:00
comfyanonymous
4aa79dbf2c Adjust flux mem usage factor a bit. (#9588) 2025-08-27 23:08:17 -04:00
comfyanonymous
38f697d953 Add a LatentConcat node. (#9587) 2025-08-27 22:28:10 -04:00
Jedrzej Kosinski
dd21b4aa51 Made WAN attention receive transformer_options, test node added to wan to test out attention override later 2025-08-27 17:56:21 -07:00
Jedrzej Kosinski
29b7990dc2 Fix memory usage issue with inspect 2025-08-27 17:55:35 -07:00
Jedrzej Kosinski
68b00e9c60 Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added 2025-08-27 17:13:33 -07:00
Gangin Park
3aad339b63 Add DPM++ 2M SDE Heun (RES) sampler (#9542) 2025-08-27 19:07:31 -04:00
comfyanonymous
491755325c Better s2v memory estimation. (#9584) 2025-08-27 19:02:42 -04:00
Jedrzej Kosinski
b58db6934c Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options 2025-08-27 14:18:18 -07:00
comfyanonymous
496888fd68 Improve s2v performance when generating videos longer than 120 frames. (#9582) 2025-08-27 16:06:40 -04:00
comfyanonymous
b5ac6ed7ce Fixes to make controlnet type models work on qwen edit and kontext. (#9581) 2025-08-27 15:26:28 -04:00
Kohaku-Blueleaf
b20ba1f27c Fix #9537 (#9576) 2025-08-27 12:45:02 -04:00
comfyanonymous
31a37686d0 Negative audio in s2v should be zeros. (#9578) 2025-08-27 12:44:29 -04:00
comfyanonymous
88aee596a3 WIP Wan 2.2 S2V model. (#9568) 2025-08-27 01:10:34 -04:00
ComfyUI Wiki
6a193ac557 Update template to 0.1.68 (#9569)
* Update template to 0.1.67

* Update template to 0.1.68
2025-08-27 00:10:20 -04:00
Jedrzej Kosinski
47f4db3e84 Adding Google Gemini Image API node (#9566)
* bigcat88's progress on adding Google Gemini Image node

* Made Google Gemini Image node functional

* Bump frontend version to get static pricing badge on Gemini Image node
2025-08-26 22:20:44 -04:00
ComfyUI Wiki
5352abc6d3 Update template to 0.1.66 (#9557) 2025-08-26 13:33:54 -04:00
comfyanonymous
39aa06bd5d Make AudioEncoderOutput usable in v3 node schema. (#9554) 2025-08-26 12:50:46 -04:00
comfyanonymous
914c2a2973 Implement wav2vec2 as an audio encoder model. (#9549)
This is useless on its own but there are multiple models that use it.
2025-08-25 23:26:47 -04:00
comfyanonymous
e633a47ad1 Add models/audio_encoders directory. (#9548) 2025-08-25 20:13:54 -04:00
comfyanonymous
f6b93d41a0 Remove models from readme that are not fully implemented. (#9535)
Cosmos model implementations are currently missing the safety part so it is technically not fully implemented and should not be advertised as such.
2025-08-24 15:40:32 -04:00
blepping
95ac7794b7 Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling (#9528)
* Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling

* Fix missing LazyCache check_metadata method
Ensure LazyCache reset method resets all the tensor state values
2025-08-24 15:29:49 -04:00
comfyanonymous
71ed4a399e ComfyUI version 0.3.52 2025-08-23 18:57:09 -04:00
Christian Byrne
3e316c6338 Update frontend to v1.25.10 and revert navigation mode override (#9522)
- Update comfyui-frontend-package from 1.25.9 to 1.25.10
- Revert forced legacy navigation mode from PR #9518
- Frontend v1.25.10 includes proper navigation mode fixes and improved display text
2025-08-23 17:54:01 -04:00
comfyanonymous
8be0d22ab7 Don't use the annoying new navigation mode by default. (#9518) 2025-08-23 13:56:17 -04:00
comfyanonymous
59eddda900 Python 3.13 is well supported. (#9511) 2025-08-23 01:36:44 -04:00
comfyanonymous
41048c69b4 Fix Conditioning masks on 3d latents. (#9506) 2025-08-22 23:15:44 -04:00
Jedrzej Kosinski
fc247150fe Implement EasyCache and Invent LazyCache (#9496)
* Attempting a universal implementation of EasyCache, starting with flux as test; I screwed up the math a bit, but when I set it just right it works.

* Fixed math to make threshold work as expected, refactored code to use EasyCacheHolder instead of a dict wrapped by object

* Use sigmas from transformer_options instead of timesteps to be compatible with a greater amount of models, make end_percent work

* Make log statement when not skipping useful, preparing for per-cond caching

* Added DIFFUSION_MODEL wrapper around forward function for wan model

* Add subsampling for heuristic inputs

* Add subsampling to output_prev (output_prev_subsampled now)

* Properly consider conds in EasyCache logic

* Created SuperEasyCache to test what happens if caching and reuse is moved outside the scope of conds, added PREDICT_NOISE wrapper to facilitate this test

* Change max reuse_threshold to 3.0

* Mark EasyCache/SuperEasyCache as experimental (beta)

* Make Lumina2 compatible with EasyCache

* Add EasyCache support for Qwen Image

* Fix missing comma, curse you Cursor

* Add EasyCache support to AceStep

* Add EasyCache support to Chroma

* Added EasyCache support to Cosmos Predict t2i

* Make EasyCache not crash with Cosmos Predict ImagToVideo latents, but does not work well at all

* Add EasyCache support to hidream

* Added EasyCache support to hunyuan video

* Added EasyCache support to hunyuan3d

* Added EasyCache support to LTXV (not very good, but does not crash)

* Implemented EasyCache for aura_flow

* Renamed SuperEasyCache to LazyCache, hardcoded subsample_factor to 8 on nodes

* Eatra logging when verbose is true for EasyCache
2025-08-22 22:41:08 -04:00
contentis
fe31ad0276 Add elementwise fusions (#9495)
* Add elementwise fusions

* Add addcmul pattern to Qwen
2025-08-22 19:39:15 -04:00
ComfyUI Wiki
ca4e96a8ae Update template to 0.1.65 (#9501) 2025-08-22 17:40:18 -04:00
Alexander Piskun
050c67323c feat(api-nodes): add copy button to Gemini Chat node (#9440) 2025-08-22 10:51:14 -07:00
Alexander Piskun
497d41fb50 feat(api-nodes): change "OpenAI Chat" display name to "OpenAI ChatGPT" (#9443) 2025-08-22 10:50:35 -07:00
comfyanonymous
ff57793659 Support InstantX Qwen controlnet. (#9488) 2025-08-22 00:53:11 -04:00
comfyanonymous
f7bd5e58dd Make it easier to implement future qwen controlnets. (#9485) 2025-08-21 23:18:04 -04:00
Alexander Piskun
7ed73d12d1 [V3] convert Ideogram API nodes to the V3 schema (#9278)
* convert Ideogram API nodes to the V3 schema

* use auth_kwargs instead of auth_token/comfy_api_key
2025-08-21 22:06:51 -04:00
Alexander Piskun
eb39019daa [V3] convert Google Veo API node to the V3 schema (#9272)
* convert Google Veo API node to the V3 schema

* use own full io.Schema for Veo3VideoGenerationNode

* fixed typo

* use auth_kwargs instead of auth_token/comfy_api_key
2025-08-21 22:06:13 -04:00
Alexander Piskun
bab08f40d1 v3 nodes (part a) (#9149) 2025-08-21 22:05:36 -04:00
Alexander Piskun
bc49106837 convert String nodes to V3 schema (#9370) 2025-08-21 22:03:57 -04:00
comfyanonymous
1b2de2642d Support diffsynth inpaint controlnet (model patch). (#9471) 2025-08-21 00:33:49 -04:00
comfyanonymous
9fa1036f60 Forgot this. (#9470) 2025-08-20 23:09:35 -04:00
saurabh-pingale
0737b7e0d2 fix(userdata): catch invalid workflow filenames (#9434) (#9445) 2025-08-20 22:27:57 -04:00
comfyanonymous
0963493a9c Support for Qwen Diffsynth Controlnets canny and depth. (#9465)
These are not real controlnets but actually a patch on the model so they
will be treated as such.

Put them in the models/model_patches/ folder.

Use the new ModelPatchLoader and QwenImageDiffsynthControlnet nodes.
2025-08-20 22:26:37 -04:00
comfyanonymous
e73a9dbe30 Add that qwen edit model is supported to readme. (#9463) 2025-08-20 17:34:13 -04:00
Harel Cain
fe01885acf LTXV: fix key frame noise mask dimensions for when real noise mask exists (#9425) 2025-08-20 03:33:10 -04:00
69 changed files with 4789 additions and 1171 deletions

View File

@@ -65,18 +65,17 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
@@ -191,7 +190,7 @@ comfy install
## Manual Install (Windows, Linux)
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
Git clone this repo.

View File

@@ -363,10 +363,17 @@ class UserManager():
if not overwrite and os.path.exists(path):
return web.Response(status=409, text="File already exists")
body = await request.read()
try:
body = await request.read()
with open(path, "wb") as f:
f.write(body)
with open(path, "wb") as f:
f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(
status=400,
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
)
user_path = self.get_request_user_filepath(request, None)
if full_info:

View File

@@ -0,0 +1,42 @@
from .wav2vec2 import Wav2Vec2Model
import comfy.model_management
import comfy.ops
import comfy.utils
import logging
import torchaudio
class AudioEncoderModel():
def __init__(self, config):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_audio(self, audio, sample_rate):
comfy.model_management.load_model_gpu(self.patcher)
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
out, all_layers = self.model(audio.to(self.load_device))
outputs = {}
outputs["encoded_audio"] = out
outputs["encoded_audio_all_layers"] = all_layers
return outputs
def load_audio_encoder_from_sd(sd, prefix=""):
audio_encoder = AudioEncoderModel(None)
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
return audio_encoder

View File

@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_masked
class LayerNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
class ConvFeatureEncoder(nn.Module):
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
super().__init__()
self.conv_layers = nn.ModuleList([
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
])
def forward(self, x):
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x.transpose(1, 2)
class FeatureProjection(nn.Module):
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
super().__init__()
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.layer_norm(x)
x = self.projection(x)
return x
class PositionalConvEmbedding(nn.Module):
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
super().__init__()
self.conv = nn.Conv1d(
embed_dim,
embed_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
groups=groups,
)
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
self.activation = nn.GELU()
def forward(self, x):
x = x.transpose(1, 2)
x = self.conv(x)[:, :, :-1]
x = self.activation(x)
x = x.transpose(1, 2)
return x
class TransformerEncoder(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
num_layers=12,
mlp_ratio=4.0,
dtype=None, device=None, operations=None
):
super().__init__()
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
self.layers = nn.ModuleList([
TransformerEncoderLayer(
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_layers)
])
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
def forward(self, x, mask=None):
x = x + self.pos_conv_embed(x)
all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x, mask)
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class Attention(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x, mask=None):
assert (mask is None) # TODO?
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention_masked(q, k, v, self.num_heads)
return self.out_proj(out)
class FeedForward(nn.Module):
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__()
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.intermediate_dense(x)
x = torch.nn.functional.gelu(x)
x = self.output_dense(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
dtype=None, device=None, operations=None
):
super().__init__()
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None):
residual = x
x = self.layer_norm(x)
x = self.attention(x, mask=mask)
x = residual + x
x = x + self.feed_forward(self.final_layer_norm(x))
return x
class Wav2Vec2Model(nn.Module):
"""Complete Wav2Vec 2.0 model."""
def __init__(
self,
embed_dim=1024,
final_dim=256,
num_heads=16,
num_layers=24,
dtype=None, device=None, operations=None
):
super().__init__()
conv_dim = 512
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
self.encoder = TransformerEncoder(
embed_dim=embed_dim,
num_heads=num_heads,
num_layers=num_layers,
device=device, dtype=dtype, operations=operations
)
def forward(self, x, mask_time_indices=None, return_dict=False):
x = torch.mean(x, dim=1)
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
features = self.feature_extractor(x)
features = self.feature_projection(features)
batch_size, seq_len, _ = features.shape
x, all_x = self.encoder(features)
return x, all_x

View File

@@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
self.cond_hint = None
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
compression_ratio *= self.vae.spacial_compression_encode()
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
@@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_qwen_instantx(sd, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.Wan21()
extra_conds = []
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

View File

@@ -853,6 +853,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
return x
@torch.no_grad()
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@@ -925,6 +930,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:

View File

@@ -133,6 +133,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
@@ -140,6 +141,7 @@ class Attention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
transformer_options=transformer_options,
**cross_attention_kwargs,
)
@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)
# linear proj
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
transformer_options={},
):
N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
transformer_options=transformer_options,
)
if self.use_adaln_single:
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states

View File

@@ -19,6 +19,7 @@ import torch
from torch import nn
import comfy.model_management
import comfy.patcher_extension
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
@@ -313,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@@ -338,12 +340,34 @@ class ACEStepTransformer2DModel(nn.Module):
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
transformer_options=transformer_options,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
def forward(
def forward(self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
controlnet_scale, lyrics_strength, **kwargs)
def _forward(
self,
x,
timestep,
@@ -371,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length = hidden_states.shape[-1]
transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@@ -380,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
transformer_options=transformer_options,
)
return output

View File

@@ -298,7 +298,8 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None
causal = None,
transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
@@ -363,7 +364,7 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
out = optimized_attention(q, k, v, h, skip_reshape=True)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)
if mask is not None:
@@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None
rotary_pos_emb = None,
transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
@@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
x = x + residual
else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
transformer_options = kwargs.get("transformer_options", {})
patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
@@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.patcher_extension
import comfy.ldm.common_dit
def modulate(x, shift, scale):
@@ -84,7 +85,7 @@ class SingleAttention(nn.Module):
)
#@torch.compile()
def forward(self, c):
def forward(self, c, transformer_options={}):
bsz, seqlen1, _ = c.shape
@@ -94,7 +95,7 @@ class SingleAttention(nn.Module):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c
@@ -143,7 +144,7 @@ class DoubleAttention(nn.Module):
#@torch.compile()
def forward(self, c, x):
def forward(self, c, x, transformer_options={}):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@@ -167,7 +168,7 @@ class DoubleAttention(nn.Module):
torch.cat([cv, xv], dim=1),
)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@@ -206,7 +207,7 @@ class MMDiTBlock(nn.Module):
self.is_last = is_last
#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
cres, xres = c, x
@@ -224,7 +225,7 @@ class MMDiTBlock(nn.Module):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
c, x = self.attn(c, x)
c, x = self.attn(c, x, transformer_options=transformer_options)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -254,13 +255,13 @@ class DiTBlock(nn.Module):
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, transformer_options, **kwargs)
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape
@@ -465,13 +473,14 @@ class MMDiT(nn.Module):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
args["vec"],
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
@@ -480,13 +489,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
x = cx[:, c_len:]

View File

@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v):
def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads)
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
return self.out_proj(out)
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False):
def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv):
def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x

View File

@@ -173,7 +173,7 @@ class StageB(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip):
def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -199,7 +199,7 @@ class StageB(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip):
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -228,7 +228,7 @@ class StageB(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
@@ -245,8 +245,8 @@ class StageB(nn.Module):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@@ -182,7 +182,7 @@ class StageC(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, cnet=None):
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -201,7 +201,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -213,7 +213,7 @@ class StageC(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -235,7 +235,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -247,7 +247,7 @@ class StageC(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
@@ -262,8 +262,8 @@ class StageC(nn.Module):
# Model Blocks
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, clip, cnet)
x = self._up_decode(level_outputs, r_embed, clip, cnet)
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
@@ -192,14 +193,16 @@ class Chroma(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -208,7 +211,8 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -228,17 +232,19 @@ class Chroma(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -253,6 +259,13 @@ class Chroma(nn.Module):
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

View File

@@ -176,6 +176,7 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
transformer_options={},
**kwargs,
):
"""
@@ -184,7 +185,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
return x

View File

@@ -27,6 +27,8 @@ from torchvision import transforms
from enum import Enum
import logging
import comfy.patcher_extension
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x,
timesteps,
context,
attention_mask,
fps,
image_size,
padding_mask,
scalar_feature,
data_type,
latent_condition,
latent_condition_sigma,
condition_video_augment_sigma,
**kwargs)
def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:
@@ -482,6 +520,7 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@@ -496,6 +535,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@@ -11,6 +11,7 @@ import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
def apply_rotary_pos_emb(
@@ -43,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@@ -70,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
class Attention(nn.Module):
@@ -179,8 +180,8 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
@@ -188,6 +189,7 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@@ -195,7 +197,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
@@ -458,6 +460,7 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -511,6 +514,7 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -524,6 +528,7 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@@ -533,6 +538,7 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -546,6 +552,7 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -805,7 +812,21 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
def forward(
def forward(self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
@@ -850,6 +871,7 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
for block in self.blocks:
x_B_T_H_W_D = block(

View File

@@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
@@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
@@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)

View File

@@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
@@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x

View File

@@ -6,6 +6,7 @@ import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension
from .layers import (
DoubleStreamBlock,
@@ -127,6 +128,7 @@ class Flux(nn.Module):
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
transformer_options["block"] = ("double_block", i, 2)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -134,14 +136,16 @@ class Flux(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -150,14 +154,15 @@ class Flux(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img[:, :add.shape[1]] += add
if img.dtype == torch.float16:
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
@@ -165,30 +170,33 @@ class Flux(nn.Module):
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
transformer_options["block"] = ("single_block", i, 1)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
img = img[:, txt.shape[1] :, ...]
@@ -214,6 +222,13 @@ class Flux(nn.Module):
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size

View File

@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
transformer_options={},
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
xy = optimized_attention(q,
k,
v, self.num_heads, skip_reshape=True)
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
transformer_options={},
**attn_kwargs,
):
"""Forward pass of a block.
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
transformer_options=transformer_options,
**attn_kwargs,
)
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
crop_y=args["num_tokens"],
transformer_options=args["transformer_options"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
transformer_options=transformer_options,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.

View File

@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.patcher_extension
import comfy.ldm.common_dit
@@ -71,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
class HiDreamAttnProcessor_flashattn:
@@ -85,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
*args,
**kwargs,
) -> torch.FloatTensor:
@@ -132,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
hidden_states = attention(query, key, value, transformer_options=transformer_options)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -198,6 +200,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.Tensor:
return self.processor(
self,
@@ -205,6 +208,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
@@ -405,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -418,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens,
image_tokens_masks,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -482,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -499,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks,
norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -549,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
return self.block(
image_tokens,
@@ -556,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens,
adaln_input,
rope,
transformer_options=transformer_options,
)
@@ -692,7 +701,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
def forward(self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
image_cond=None,
control = None,
transformer_options = {},
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
def _forward(
self,
x: torch.Tensor,
t: torch.Tensor,
@@ -769,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
transformer_options=transformer_options,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
@@ -792,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
transformer_options=transformer_options,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1

View File

@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
SingleStreamBlock,
timestep_embedding,
)
import comfy.patcher_extension
class Hunyuan3Dv2(nn.Module):
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context
@@ -91,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@@ -107,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
img = torch.cat((txt, img), 1)
@@ -118,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)

View File

@@ -1,6 +1,7 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
import comfy.patcher_extension
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
@@ -77,13 +78,13 @@ class TokenRefinerBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
def forward(self, x, c, mask, transformer_options={}):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@@ -114,14 +115,14 @@ class IndividualTokenRefiner(nn.Module):
]
)
def forward(self, x, c, mask):
def forward(self, x, c, mask, transformer_options={}):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
x = block(x, c, m, transformer_options=transformer_options)
return x
@@ -149,6 +150,7 @@ class TokenRefiner(nn.Module):
x,
timesteps,
mask,
transformer_options={},
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
@@ -157,7 +159,7 @@ class TokenRefiner(nn.Module):
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
return x
class HunyuanVideo(nn.Module):
@@ -266,7 +268,7 @@ class HunyuanVideo(nn.Module):
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -284,14 +286,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -306,13 +308,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -348,6 +350,13 @@ class HunyuanVideo(nn.Module):
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@@ -1,5 +1,6 @@
import torch
from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
@@ -270,7 +271,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@@ -284,9 +285,9 @@ class CrossAttention(nn.Module):
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@@ -302,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)
@@ -471,10 +479,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@@ -482,7 +490,8 @@ class LTXVModel(torch.nn.Module):
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
pe=pe,
transformer_options=transformer_options,
)
# 3. Output

View File

@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.patcher_extension
def modulate(x, scale):
@@ -103,6 +104,7 @@ class JointAttention(nn.Module):
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
transformer_options={},
) -> torch.Tensor:
"""
@@ -139,7 +141,7 @@ class JointAttention(nn.Module):
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
return self.out(output)
@@ -267,6 +269,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
transformer_options={},
):
"""
Perform a forward pass through the TransformerBlock.
@@ -289,6 +292,7 @@ class JointTransformerBlock(nn.Module):
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -303,6 +307,7 @@ class JointTransformerBlock(nn.Module):
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + self.ffn_norm2(
@@ -493,7 +498,7 @@ class NextDiT(nn.Module):
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
@@ -553,7 +558,7 @@ class NextDiT(nn.Module):
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
# refine image
flat_x = []
@@ -572,7 +577,7 @@ class NextDiT(nn.Module):
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@@ -590,8 +595,15 @@ class NextDiT(nn.Module):
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
# def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -608,12 +620,13 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input)
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]

View File

@@ -5,8 +5,9 @@ import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional
from typing import Optional, Any, Callable, Union
import logging
import functools
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -17,23 +18,45 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError as e:
SAGE_ATTENTION_IS_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError as e:
if model_management.sage_attention_enabled():
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError:
if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
if name not in REGISTERED_ATTENTION_FUNCTIONS:
REGISTERED_ATTENTION_FUNCTIONS[name] = func
else:
logging.warning(f"Attention function {name} already registered, skipping registration.")
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
if name == "optimized":
return optimized_attention
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
if default is ...:
raise KeyError(f"Attention function {name} not found.")
else:
return default
return REGISTERED_ATTENTION_FUNCTIONS[name]
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -91,7 +114,27 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
def wrap_attn(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
remove_attn_wrapper_key = False
try:
if "_inside_attn_wrapper" not in kwargs:
transformer_options = kwargs.get("transformer_options", None)
remove_attn_wrapper_key = True
kwargs["_inside_attn_wrapper"] = True
if transformer_options is not None:
if "optimized_attention_override" in transformer_options:
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
return func(*args, **kwargs)
finally:
if remove_attn_wrapper_key:
del kwargs["_inside_attn_wrapper"]
return wrapper
@wrap_attn
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -359,7 +403,8 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@@ -427,8 +472,8 @@ else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
if tensor_layout == "HND":
if not skip_output_reshape:
@@ -534,8 +579,8 @@ except AttributeError as error:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -597,6 +642,19 @@ else:
optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
@@ -629,7 +687,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, value=None, mask=None):
def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
@@ -640,9 +698,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@@ -746,7 +804,7 @@ class BasicTransformerBlock(nn.Module):
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n)
else:
n = self.attn1(n, context=context_attn1, value=value_attn1)
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
@@ -786,7 +844,7 @@ class BasicTransformerBlock(nn.Module):
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@@ -1017,7 +1075,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)

View File

@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
def modulate(x, shift, scale):
if shift is None:
shift = torch.zeros_like(scale)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))
#################################################################################
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
assert not self.pre_only
attn1 = self.attn.post_attention(attn)
attn2 = self.attn2.post_attention(attn2)
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = x + out1
x = x + out2
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
x = x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
)
return self.post_attention(attn, *intermediates)
def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
out1 = gate_msa.unsqueeze(1) * attn1
out2 = gate_msa2.unsqueeze(1) * attn2
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
return x
def block_mixing(*args, use_checkpoint=True, **kwargs):
if use_checkpoint:
@@ -604,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
return _block_mixing(*args, **kwargs)
def _block_mixing(context, x, context_block, x_block, c):
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
@@ -620,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
transformer_options=transformer_options,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@@ -635,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
transformer_options=transformer_options,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
@@ -956,10 +960,10 @@ class MMDiT(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
@@ -968,6 +972,7 @@ class MMDiT(nn.Module):
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
transformer_options=transformer_options,
)
if control is not None:
control_o = control.get("output")

View File

@@ -120,7 +120,7 @@ class Attention(nn.Module):
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
@@ -146,7 +146,7 @@ class Attention(nn.Module):
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
@@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
transformer_options=transformer_options,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb)

View File

@@ -0,0 +1,77 @@
import torch
import math
from .model import QwenImageTransformer2DModel
class QwenImageControlNetModel(QwenImageTransformer2DModel):
def __init__(
self,
extra_condition_channels=0,
dtype=None,
device=None,
operations=None,
**kwargs
):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 60
# controlnet_blocks
self.controlnet_blocks = torch.nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
hint=None,
transformer_options={},
**kwargs
):
timestep = timesteps
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
hidden_states, img_ids, orig_shape = self.process_img(x)
hint, _, _ = self.process_img(hint)
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
controlnet_block_samples = ()
for i, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
return {"input": controlnet_block_samples[:self.main_model_double]}

View File

@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -131,6 +132,7 @@ class Attention(nn.Module):
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
@@ -158,7 +160,7 @@ class Attention(nn.Module):
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -214,9 +216,9 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
def _modulate(self, x, mod_params):
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
self,
@@ -225,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
@@ -241,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
@@ -248,11 +252,11 @@ class QwenImageTransformerBlock(nn.Module):
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states
@@ -275,7 +279,7 @@ class LastLayer(nn.Module):
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
return x
@@ -293,6 +297,7 @@ class QwenImageTransformer2DModel(nn.Module):
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
image_model=None,
final_layer=True,
dtype=None,
device=None,
operations=None,
@@ -300,6 +305,7 @@ class QwenImageTransformer2DModel(nn.Module):
super().__init__()
self.dtype = dtype
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
@@ -329,9 +335,9 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers)
])
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
if final_layer:
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape
@@ -353,7 +359,14 @@ class QwenImageTransformer2DModel(nn.Module):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def forward(
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
def _forward(
self,
x,
timesteps,
@@ -362,6 +375,7 @@ class QwenImageTransformer2DModel(nn.Module):
guidance: torch.Tensor = None,
ref_latents=None,
transformer_options={},
control=None,
**kwargs
):
timestep = timesteps
@@ -416,15 +430,16 @@ class QwenImageTransformer2DModel(nn.Module):
)
patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
@@ -434,8 +449,22 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states[:, :add.shape[1]] += add
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@@ -4,13 +4,14 @@ import math
import torch
import torch.nn as nn
from einops import repeat
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
def sinusoidal_embedding_1d(dim, position):
@@ -51,7 +52,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
def forward(self, x, freqs, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -74,6 +75,7 @@ class WanSelfAttention(nn.Module):
k.view(b, s, n * d),
v,
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
@@ -82,7 +84,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, **kwargs):
def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -94,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = self.o(x)
return x
@@ -115,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len):
def forward(self, x, context, context_img_len, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -130,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention):
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
# output
x = x + img_x
@@ -148,11 +150,14 @@ WAN_CROSSATTENTION_CLASSES = {
def repeat_e(e, x):
repeats = 1
if e.shape[1] > 1:
repeats = x.shape[1] // e.shape[1]
if e.size(1) > 1:
repeats = x.size(1) // e.size(1)
if repeats == 1:
return e
return torch.repeat_interleave(e, repeats, dim=1)
if repeats * e.size(1) == x.size(1):
return torch.repeat_interleave(e, repeats, dim=1)
else:
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
class WanAttentionBlock(nn.Module):
@@ -202,6 +207,7 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_img_len=257,
transformer_options={},
):
r"""
Args:
@@ -219,15 +225,15 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
freqs)
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
x = x + y * repeat_e(e[2], x)
x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
x = x + y * repeat_e(e[5], x)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -342,7 +348,7 @@ class Head(nn.Module):
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
return x
@@ -555,12 +561,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
@@ -572,30 +578,49 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return freqs
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
t_len = x.shape[2]
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@@ -719,17 +744,17 @@ class VaceWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x += c_skip * vace_strength[iii]
del c_skip
# head
@@ -818,12 +843,12 @@ class CameraWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
@@ -831,3 +856,468 @@ class CameraWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class CausalConv1d(nn.Module):
def __init__(self,
chan_in,
chan_out,
kernel_size=3,
stride=1,
dilation=1,
pad_mode='replicate',
operations=None,
**kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = operations.Conv1d(
chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
def forward(self, x):
x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads=int,
need_global=True,
dtype=None,
device=None,
operations=None,):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.need_global = need_global
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
if need_global:
self.conv1_global = CausalConv1d(
in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
self.norm1 = operations.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
if need_global:
self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
self.norm1 = operations.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm2 = operations.LayerNorm(
hidden_dim // 2,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm3 = operations.LayerNorm(
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
def forward(self, x):
x = rearrange(x, 'b t c -> b c t')
x_ori = x.clone()
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
if not self.need_global:
return x_local
x = self.conv1_global(x_ori)
x = rearrange(x, 'b c t -> b t c')
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = self.final_linear(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
return x, x_local
class CausalAudioEncoder(nn.Module):
def __init__(self,
dim=5120,
num_layers=25,
out_dim=2048,
video_rate=8,
num_token=4,
need_global=False,
dtype=None,
device=None,
operations=None):
super().__init__()
self.encoder = MotionEncoder_tc(
in_dim=dim,
hidden_dim=out_dim,
num_heads=num_token,
need_global=need_global, dtype=dtype, device=device, operations=operations)
weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
self.weights = torch.nn.Parameter(weight)
self.act = torch.nn.SiLU()
def forward(self, features):
# features B * num_layers * dim * video_length
weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
weights_sum = weights.sum(dim=1, keepdims=True)
weighted_feat = ((features * weights) / weights_sum).sum(
dim=1) # b dim f
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
res = self.encoder(weighted_feat) # b f n dim
return res # b f n dim
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
super().__init__()
output_dim = output_dim or embedding_dim * 2
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
def forward(self, x, temb):
temb = self.linear(self.silu(temb))
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
x = self.norm(x) * (1 + scale) + shift
return x
class AudioInjector_WAN(nn.Module):
def __init__(self,
dim=2048,
num_heads=32,
inject_layer=[0, 27],
root_net=None,
enable_adain=False,
adain_dim=2048,
adain_mode=None,
dtype=None,
device=None,
operations=None):
super().__init__()
self.enable_adain = enable_adain
self.adain_mode = adain_mode
self.injected_block_id = {}
audio_injector_id = 0
for inject_id in inject_layer:
self.injected_block_id[inject_id] = audio_injector_id
audio_injector_id += 1
self.injector = nn.ModuleList([
WanT2VCrossAttention(
dim=dim,
num_heads=num_heads,
qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
) for _ in range(audio_injector_id)
])
self.injector_pre_norm_feat = nn.ModuleList([
operations.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6, dtype=dtype, device=device
) for _ in range(audio_injector_id)
])
self.injector_pre_norm_vec = nn.ModuleList([
operations.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6, dtype=dtype, device=device
) for _ in range(audio_injector_id)
])
if enable_adain:
self.injector_adain_layers = nn.ModuleList([
AdaLayerNorm(
output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
for _ in range(audio_injector_id)
])
if adain_mode != "attn_norm":
self.injector_adain_output_layers = nn.ModuleList(
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
audio_attn_id = self.injected_block_id.get(block_id, None)
if audio_attn_id is None:
return x
num_frames = audio_emb.shape[1]
input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
if self.enable_adain and self.adain_mode == "attn_norm":
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
attn_hidden_states = adain_hidden_states
else:
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
residual_out = rearrange(
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
x[:, :seq_len] = x[:, :seq_len] + residual_out
return x
class FramePackMotioner(nn.Module):
def __init__(
self,
inner_dim=1024,
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
zip_frame_buckets=[
1, 2, 16
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
dtype=None,
device=None,
operations=None):
super().__init__()
self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
self.zip_frame_buckets = zip_frame_buckets
self.inner_dim = inner_dim
self.num_heads = num_heads
self.drop_mode = drop_mode
def forward(self, motion_latents, rope_embedder, add_last_motion=2):
lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
if overlap_frame > 0:
padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
if add_last_motion < 2 and self.drop_mode != "drop":
zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
padd_lat[:, :, -zero_end_frame:] = 0
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1
# patchfy
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
clean_latents_2x = self.proj_2x(clean_latents_2x)
l_2x_shape = clean_latents_2x.shape
clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
clean_latents_4x = self.proj_4x(clean_latents_4x)
l_4x_shape = clean_latents_4x.shape
clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
if add_last_motion < 2 and self.drop_mode == "drop":
clean_latents_post = clean_latents_post[:, :
0] if add_last_motion < 2 else clean_latents_post
clean_latents_2x = clean_latents_2x[:, :
0] if add_last_motion < 1 else clean_latents_2x
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
return motion_lat, rope
class WanModel_S2V(WanModel):
def __init__(self,
model_type='s2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
audio_dim=1024,
num_audio_token=4,
enable_adain=True,
cond_dim=16,
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
adain_mode="attn_norm",
framepack_drop_mode="padd",
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
self.casual_audio_encoder = CausalAudioEncoder(
dim=audio_dim,
out_dim=self.dim,
num_token=num_audio_token,
need_global=enable_adain, dtype=dtype, device=device, operations=operations)
if cond_dim > 0:
self.cond_encoder = operations.Conv3d(
cond_dim,
self.dim,
kernel_size=self.patch_size,
stride=self.patch_size, device=device, dtype=dtype)
self.audio_injector = AudioInjector_WAN(
dim=self.dim,
num_heads=self.num_heads,
inject_layer=audio_inject_layers,
root_net=self,
enable_adain=enable_adain,
adain_dim=self.dim,
adain_mode=adain_mode,
dtype=dtype, device=device, operations=operations
)
self.frame_packer = FramePackMotioner(
inner_dim=self.dim,
num_heads=self.num_heads,
zip_frame_buckets=[1, 2, 16],
drop_mode=framepack_drop_mode,
dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
x,
t,
context,
audio_embed=None,
reference_latent=None,
control_video=None,
reference_motion=None,
clip_fea=None,
freqs=None,
transformer_options={},
**kwargs,
):
if audio_embed is not None:
num_embeds = x.shape[-3] * 4
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
else:
audio_emb = None
# embeddings
bs, _, time, height, width = x.shape
x = self.patch_embedding(x.float()).to(x.dtype)
if control_video is not None:
x = x + self.cond_encoder(control_video)
if t.ndim == 1:
t = t.unsqueeze(1).repeat(1, x.shape[2])
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
seq_len = x.size(1)
cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
x = x + cond_mask_weight[0]
if reference_latent is not None:
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
ref = ref.flatten(2).transpose(1, 2)
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
ref = ref + cond_mask_weight[1]
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
del ref, freqs_ref
if reference_motion is not None:
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
motion_encoded = motion_encoded + cond_mask_weight[2]
x = torch.cat([x, motion_encoded], dim=1)
freqs = torch.cat([freqs, freqs_motion], dim=1)
t = torch.repeat_interleave(t, 2, dim=1)
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
del motion_encoded, freqs_motion
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module):
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
self.memory_usage_shape_process = {}
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -350,8 +351,15 @@ class BaseModel(torch.nn.Module):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
if shape is not None and len(shape) > 0:
input_shapes += shape
if shape is not None:
if c in self.memory_usage_shape_process:
out = []
for s in shape:
out.append(self.memory_usage_shape_process[c](s))
shape = out
if len(shape) > 0:
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
@@ -1102,9 +1110,10 @@ class WAN21(BaseModel):
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], 16):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
for i in range(0, image.shape[1], latent_dim):
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
if extra_channels != image.shape[1] + 4:
@@ -1201,18 +1210,50 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN22(BaseModel):
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
reference_motion = kwargs.get("reference_motion", None)
if reference_motion is not None:
out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
control_video = kwargs.get("control_video", None)
if control_video is not None:
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
reference_motion = kwargs.get("reference_motion", None)
if reference_motion is not None:
out['reference_motion'] = reference_motion.shape
return out
class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
denoise_mask = kwargs.get("denoise_mask", None)
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
return out

View File

@@ -368,6 +368,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "camera"
else:
dit_config["model_type"] = "camera_2.2"
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "s2v"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
@@ -492,6 +494,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:

View File

@@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models)
models_temp = set()
for m in models:
models_temp.add(m)
for mm in m.model_patches_models():
models_temp.add(mm)
models = models_temp
models_to_load = []

View File

@@ -430,6 +430,9 @@ class ModelPatcher:
def set_model_forward_timestep_embed_patch(self, patch):
self.set_model_patch(patch, "forward_timestep_embed_patch")
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@@ -486,6 +489,30 @@ class ModelPatcher:
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_patches_models(self):
to = self.model_options["transformer_options"]
models = []
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "models"):
models += patch_list[i].models()
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "models"):
models += patch_list[k].models()
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "models"):
models += wrap_func.models()
return models
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()

View File

@@ -50,6 +50,7 @@ class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
PREDICT_NOISE = "predict_noise"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"

21
comfy/samplers.py Normal file → Executable file
View File

@@ -17,6 +17,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.utils
import scipy.stats
import numpy
@@ -61,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert (mask.shape[1:] == x_in.shape[2:])
# assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
mask = mask * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
else:
mask = torch.ones_like(input_x)
mult = mask * strength
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
if len(mask.shape) == len(dims):
mask = mask.unsqueeze(0)
if mask.shape[1:] != dims:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
if mask.ndim < 4:
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
else:
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
@@ -725,7 +729,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
@@ -953,7 +957,14 @@ class CFGGuider:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs)
return self.outer_predict_noise(*args, **kwargs)
def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.predict_noise,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
).execute(x, timestep, model_options, seed)
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

View File

@@ -700,7 +700,7 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux
memory_usage_factor = 2.8
memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@@ -1072,6 +1072,19 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN22_S2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "s2v",
}
def __init__(self, unet_config):
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22_S2V(self, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1272,6 +1285,6 @@ class QwenImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase):
(mat1, mat2, alpha, None, None, None, None, None, None)
)
def to_train(self):
return LokrDiff(self.weights)
@classmethod
def load(
cls,

View File

@@ -8,6 +8,7 @@ import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
@@ -282,8 +283,6 @@ class VideoFromComponents(VideoInput):
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -298,27 +297,12 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)
waveform = self.__components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))

View File

@@ -726,6 +726,18 @@ class SEGS(ComfyTypeIO):
class AnyType(ComfyTypeIO):
Type = Any
@comfytype(io_type="MODEL_PATCH")
class MODEL_PATCH(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER")
class AudioEncoder(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER_OUTPUT")
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@@ -1580,6 +1592,7 @@ class _IO:
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import List, Optional
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
from pydantic import BaseModel
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: Optional[List[str]] = None
class GeminiImageGenerateContentRequest(BaseModel):
contents: List[GeminiContent]
generationConfig: Optional[GeminiImageGenerationConfig] = None
safetySettings: Optional[List[GeminiSafetySetting]] = None
systemInstruction: Optional[GeminiSystemInstructionContent] = None
tools: Optional[List[GeminiTool]] = None
videoMetadata: Optional[GeminiVideoMetadata] = None

View File

@@ -4,8 +4,12 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
"""
from __future__ import annotations
import json
import time
import os
import uuid
import base64
from io import BytesIO
from enum import Enum
from typing import Optional, Literal
@@ -22,6 +26,7 @@ from comfy_api_nodes.apis import (
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
@@ -32,6 +37,7 @@ from comfy_api_nodes.apinode_utils import (
audio_to_base64_string,
video_to_base64_string,
tensor_to_base64_string,
bytesio_to_image_tensor,
)
@@ -50,6 +56,14 @@ class GeminiModel(str, Enum):
gemini_2_5_flash = "gemini-2.5-flash"
class GeminiImageModel(str, Enum):
"""
Gemini Image Model Names allowed by comfy-api
"""
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
def get_gemini_endpoint(
model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
@@ -72,6 +86,135 @@ def get_gemini_endpoint(
)
def get_gemini_image_endpoint(
model: GeminiImageModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiImageModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiImageGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts
def create_text_part(text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
def get_parts_from_response(
response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.
Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).
Returns:
List of response parts matching the requested type.
"""
parts = []
for part in get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.
Args:
response: The API response from Gemini.
Returns:
Combined text from all text parts in the response.
"""
parts = get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1,1024,1024,4))
return torch.cat(image_tensors, dim=0)
class GeminiNode(ComfyNodeABC):
"""
Node to generate text responses from a Gemini model.
@@ -156,59 +299,6 @@ class GeminiNode(ComfyNodeABC):
CATEGORY = "api node/text/Gemini"
API_NODE = True
def get_parts_from_response(
self, response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.
Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).
Returns:
List of response parts matching the requested type.
"""
parts = []
for part in self.get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.
Args:
response: The API response from Gemini.
Returns:
Combined text from all text parts in the response.
"""
parts = self.get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
@@ -268,43 +358,6 @@ class GeminiNode(ComfyNodeABC):
)
return audio_parts
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts
def create_text_part(self, text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
async def api_call(
self,
prompt: str,
@@ -320,11 +373,11 @@ class GeminiNode(ComfyNodeABC):
validate_string(prompt, strip_whitespace=False)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [self.create_text_part(prompt)]
parts: list[GeminiPart] = [create_text_part(prompt)]
# Add other modal parts
if images is not None:
image_parts = self.create_image_parts(images)
image_parts = create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
@@ -348,9 +401,29 @@ class GeminiNode(ComfyNodeABC):
).execute()
# Get result output
output_text = self.get_text_from_response(response)
output_text = get_text_from_response(response)
if unique_id and output_text:
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return (output_text or "Empty response from Gemini model...",)
@@ -439,12 +512,162 @@ class GeminiInputFiles(ComfyNodeABC):
return (files,)
class GeminiImage(ComfyNodeABC):
"""
Node to generate text and image responses from a Gemini model.
This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, files) to generate coherent
text and image responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for generation",
},
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiImageModel],
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
},
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
),
# TODO: later we can add this parameter later
# "n": (
# IO.INT,
# {
# "default": 1,
# "min": 1,
# "max": 8,
# "step": 1,
# "display": "number",
# "tooltip": "How many images to generate",
# },
# ),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.IMAGE, IO.STRING)
FUNCTION = "api_call"
CATEGORY = "api node/image/Gemini"
DESCRIPTION = "Edit images synchronously via Google API."
API_NODE = True
async def api_call(
self,
prompt: str,
model: GeminiImageModel,
images: Optional[IO.IMAGE] = None,
files: Optional[list[GeminiPart]] = None,
n=1,
unique_id: Optional[str] = None,
**kwargs,
):
# Validate inputs
validate_string(prompt, strip_whitespace=True, min_length=1)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [create_text_part(prompt)]
# Add other modal parts
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
if files is not None:
parts.extend(files)
response = await SynchronousOperation(
endpoint=get_gemini_image_endpoint(model),
request=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=["TEXT","IMAGE"]
)
),
auth_kwargs=kwargs,
).execute()
output_image = get_image_from_response(response)
output_text = get_text_from_response(response)
if unique_id and output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
output_text = output_text or "Empty response from Gemini model..."
return (output_image, output_text,)
NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiImageNode": GeminiImage,
"GeminiInputFiles": GeminiInputFiles,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiImageNode": "Google Gemini Image",
"GeminiInputFiles": "Gemini Input Files",
}

View File

@@ -1,8 +1,8 @@
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from PIL import Image
import numpy as np
import io
import torch
from comfy_api_nodes.apis import (
IdeogramGenerateRequest,
@@ -246,90 +246,81 @@ def display_image_urls_on_node(image_urls, node_id):
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(ComfyNodeABC):
"""
Generates images using the Ideogram V1 model.
"""
def __init__(self):
pass
class IdeogramV1(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
def define_schema(cls):
return comfy_io.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
"turbo": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
}
comfy_io.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
},
"optional": {
"aspect_ratio": (
IO.COMBO,
{
"options": list(V1_V2_RATIO_MAP.keys()),
"default": "1:1",
"tooltip": "The aspect ratio for image generation.",
},
comfy_io.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation.",
optional=True,
),
"magic_prompt_option": (
IO.COMBO,
{
"options": ["AUTO", "ON", "OFF"],
"default": "AUTO",
"tooltip": "Determine if MagicPrompt should be used in generation",
},
comfy_io.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"step": 1,
"control_after_generate": True,
"display": "number",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
"negative_prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Description of what to exclude from the image",
},
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
"num_images": (
IO.INT,
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
comfy_io.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt,
turbo=False,
aspect_ratio="1:1",
@@ -337,13 +328,15 @@ class IdeogramV1(ComfyNodeABC):
seed=0,
negative_prompt="",
num_images=1,
unique_id=None,
**kwargs,
):
# Determine the model based on turbo setting
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
@@ -364,7 +357,7 @@ class IdeogramV1(ComfyNodeABC):
negative_prompt=negative_prompt if negative_prompt else None,
)
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response = await operation.execute()
@@ -377,93 +370,85 @@ class IdeogramV1(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (await download_and_process_images(image_urls),)
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV2(ComfyNodeABC):
"""
Generates images using the Ideogram V2 model.
"""
def __init__(self):
pass
class IdeogramV2(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
def define_schema(cls):
return comfy_io.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
"turbo": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
}
comfy_io.Boolean.Input(
"turbo",
default=False,
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
),
},
"optional": {
"aspect_ratio": (
IO.COMBO,
{
"options": list(V1_V2_RATIO_MAP.keys()),
"default": "1:1",
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
},
comfy_io.Combo.Input(
"aspect_ratio",
options=list(V1_V2_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
optional=True,
),
"resolution": (
IO.COMBO,
{
"options": list(V1_V1_RES_MAP.keys()),
"default": "Auto",
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
},
comfy_io.Combo.Input(
"resolution",
options=list(V1_V1_RES_MAP.keys()),
default="Auto",
tooltip="The resolution for image generation. "
"If not set to AUTO, this overrides the aspect_ratio setting.",
optional=True,
),
"magic_prompt_option": (
IO.COMBO,
{
"options": ["AUTO", "ON", "OFF"],
"default": "AUTO",
"tooltip": "Determine if MagicPrompt should be used in generation",
},
comfy_io.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"step": 1,
"control_after_generate": True,
"display": "number",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
"style_type": (
IO.COMBO,
{
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
"default": "NONE",
"tooltip": "Style type for generation (V2 only)",
},
comfy_io.Combo.Input(
"style_type",
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
default="NONE",
tooltip="Style type for generation (V2 only)",
optional=True,
),
"negative_prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Description of what to exclude from the image",
},
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Description of what to exclude from the image",
optional=True,
),
"num_images": (
IO.INT,
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
comfy_io.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
#"color_palette": (
# IO.STRING,
@@ -473,22 +458,20 @@ class IdeogramV2(ComfyNodeABC):
# "tooltip": "Color palette preset name or hex colors with weights",
# },
#),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt,
turbo=False,
aspect_ratio="1:1",
@@ -499,8 +482,6 @@ class IdeogramV2(ComfyNodeABC):
negative_prompt="",
num_images=1,
color_palette="",
unique_id=None,
**kwargs,
):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
resolution = V1_V1_RES_MAP.get(resolution, None)
@@ -517,6 +498,10 @@ class IdeogramV2(ComfyNodeABC):
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
@@ -540,7 +525,7 @@ class IdeogramV2(ComfyNodeABC):
color_palette=color_palette if color_palette else None,
)
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response = await operation.execute()
@@ -553,108 +538,99 @@ class IdeogramV2(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (await download_and_process_images(image_urls),)
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV3(ComfyNodeABC):
"""
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
"""
def __init__(self):
pass
class IdeogramV3(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation or editing",
},
def define_schema(cls):
return comfy_io.Schema(
node_id="IdeogramV3",
display_name="Ideogram V3",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation or editing",
),
},
"optional": {
"image": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional reference image for image editing.",
},
comfy_io.Image.Input(
"image",
tooltip="Optional reference image for image editing.",
optional=True,
),
"mask": (
IO.MASK,
{
"default": None,
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
},
comfy_io.Mask.Input(
"mask",
tooltip="Optional mask for inpainting (white areas will be replaced)",
optional=True,
),
"aspect_ratio": (
IO.COMBO,
{
"options": list(V3_RATIO_MAP.keys()),
"default": "1:1",
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
},
comfy_io.Combo.Input(
"aspect_ratio",
options=list(V3_RATIO_MAP.keys()),
default="1:1",
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
optional=True,
),
"resolution": (
IO.COMBO,
{
"options": V3_RESOLUTIONS,
"default": "Auto",
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
},
comfy_io.Combo.Input(
"resolution",
options=V3_RESOLUTIONS,
default="Auto",
tooltip="The resolution for image generation. "
"If not set to Auto, this overrides the aspect_ratio setting.",
optional=True,
),
"magic_prompt_option": (
IO.COMBO,
{
"options": ["AUTO", "ON", "OFF"],
"default": "AUTO",
"tooltip": "Determine if MagicPrompt should be used in generation",
},
comfy_io.Combo.Input(
"magic_prompt_option",
options=["AUTO", "ON", "OFF"],
default="AUTO",
tooltip="Determine if MagicPrompt should be used in generation",
optional=True,
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"step": 1,
"control_after_generate": True,
"display": "number",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
"num_images": (
IO.INT,
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
comfy_io.Int.Input(
"num_images",
default=1,
min=1,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
optional=True,
),
"rendering_speed": (
IO.COMBO,
{
"options": ["BALANCED", "TURBO", "QUALITY"],
"default": "BALANCED",
"tooltip": "Controls the trade-off between generation speed and quality",
},
comfy_io.Combo.Input(
"rendering_speed",
options=["BALANCED", "TURBO", "QUALITY"],
default="BALANCED",
tooltip="Controls the trade-off between generation speed and quality",
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
)
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt,
image=None,
mask=None,
@@ -664,9 +640,11 @@ class IdeogramV3(ComfyNodeABC):
seed=0,
num_images=1,
rendering_speed="BALANCED",
unique_id=None,
**kwargs,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
@@ -686,7 +664,7 @@ class IdeogramV3(ComfyNodeABC):
# Process image
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_np)
img_byte_arr = io.BytesIO()
img_byte_arr = BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
img_binary = img_byte_arr
@@ -695,7 +673,7 @@ class IdeogramV3(ComfyNodeABC):
# Process mask - white areas will be replaced
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_byte_arr = io.BytesIO()
mask_byte_arr = BytesIO()
mask_img.save(mask_byte_arr, format="PNG")
mask_byte_arr.seek(0)
mask_binary = mask_byte_arr
@@ -729,7 +707,7 @@ class IdeogramV3(ComfyNodeABC):
"mask": mask_binary,
},
content_type="multipart/form-data",
auth_kwargs=kwargs,
auth_kwargs=auth,
)
elif image is not None or mask is not None:
@@ -770,7 +748,7 @@ class IdeogramV3(ComfyNodeABC):
response_model=IdeogramGenerateResponse,
),
request=gen_request,
auth_kwargs=kwargs,
auth_kwargs=auth,
)
# Execute the operation and process response
@@ -784,18 +762,18 @@ class IdeogramV3(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (await download_and_process_images(image_urls),)
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
NODE_CLASS_MAPPINGS = {
"IdeogramV1": IdeogramV1,
"IdeogramV2": IdeogramV2,
"IdeogramV3": IdeogramV3,
}
class IdeogramExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
IdeogramV1,
IdeogramV2,
IdeogramV3,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"IdeogramV1": "Ideogram V1",
"IdeogramV2": "Ideogram V2",
"IdeogramV3": "Ideogram V3",
}
async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension()

View File

@@ -998,7 +998,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"OpenAIDalle2": "OpenAI DALL·E 2",
"OpenAIDalle3": "OpenAI DALL·E 3",
"OpenAIGPTImage1": "OpenAI GPT Image 1",
"OpenAIChatNode": "OpenAI Chat",
"OpenAIInputFiles": "OpenAI Chat Input Files",
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
"OpenAIChatNode": "OpenAI ChatGPT",
"OpenAIInputFiles": "OpenAI ChatGPT Input Files",
"OpenAIChatConfig": "OpenAI ChatGPT Advanced Options",
}

View File

@@ -1,17 +1,18 @@
import io
import logging
import base64
import aiohttp
import torch
from io import BytesIO
from typing import Optional
from typing_extensions import override
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
VeoGenVidRequest,
VeoGenVidResponse,
VeoGenVidPollRequest,
VeoGenVidPollResponse
VeoGenVidPollResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import (
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
tensor_to_base64_string
tensor_to_base64_string,
)
AVERAGE_DURATION_VIDEO_GEN = 32
@@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona
return None
class VeoVideoGenerationNode(ComfyNodeABC):
class VeoVideoGenerationNode(comfy_io.ComfyNode):
"""
Generates videos from text prompts using Google's Veo API.
@@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC):
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text description of the video",
},
def define_schema(cls):
return comfy_io.Schema(
node_id="VeoVideoGenerationNode",
display_name="Google Veo 2 Video Generation",
category="api node/video/Veo",
description="Generates videos from text prompts using Google's Veo 2 API",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the video",
),
"aspect_ratio": (
IO.COMBO,
{
"options": ["16:9", "9:16"],
"default": "16:9",
"tooltip": "Aspect ratio of the output video",
},
comfy_io.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
default="16:9",
tooltip="Aspect ratio of the output video",
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Negative text prompt to guide what to avoid in the video",
},
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
optional=True,
),
"duration_seconds": (
IO.INT,
{
"default": 5,
"min": 5,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "Duration of the output video in seconds",
},
comfy_io.Int.Input(
"duration_seconds",
default=5,
min=5,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
"enhance_prompt": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Whether to enhance the prompt with AI assistance",
}
comfy_io.Boolean.Input(
"enhance_prompt",
default=True,
tooltip="Whether to enhance the prompt with AI assistance",
optional=True,
),
"person_generation": (
IO.COMBO,
{
"options": ["ALLOW", "BLOCK"],
"default": "ALLOW",
"tooltip": "Whether to allow generating people in the video",
},
comfy_io.Combo.Input(
"person_generation",
options=["ALLOW", "BLOCK"],
default="ALLOW",
tooltip="Whether to allow generating people in the video",
optional=True,
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFF,
"step": 1,
"display": "number",
"control_after_generate": True,
"tooltip": "Seed for video generation (0 for random)",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFF,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
"image": (IO.IMAGE, {
"default": None,
"tooltip": "Optional reference image to guide video generation",
}),
"model": (
IO.COMBO,
{
"options": ["veo-2.0-generate-001"],
"default": "veo-2.0-generate-001",
"tooltip": "Veo 2 model to use for video generation",
},
comfy_io.Image.Input(
"image",
tooltip="Optional reference image to guide video generation",
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
comfy_io.Combo.Input(
"model",
options=["veo-2.0-generate-001"],
default="veo-2.0-generate-001",
tooltip="Veo 2 model to use for video generation",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "generate_video"
CATEGORY = "api node/video/Veo"
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
API_NODE = True
async def generate_video(
self,
@classmethod
async def execute(
cls,
prompt,
aspect_ratio="16:9",
negative_prompt="",
@@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
image=None,
model="veo-2.0-generate-001",
generate_audio=False,
unique_id: Optional[str] = None,
**kwargs,
):
# Prepare the instances for the request
instances = []
@@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
if "veo-3.0" in model:
parameters["generateAudio"] = generate_audio
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Initial request to start video generation
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -214,7 +209,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
instances=instances,
parameters=parameters
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
initial_response = await initial_operation.execute()
@@ -248,10 +243,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
request=VeoGenVidPollRequest(
operationName=operation_name
),
auth_kwargs=kwargs,
auth_kwargs=auth,
poll_interval=5.0,
result_url_extractor=get_video_url_from_response,
node_id=unique_id,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
@@ -304,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
logging.info("Video generation completed successfully")
# Convert video data to BytesIO object
video_io = io.BytesIO(video_data)
video_io = BytesIO(video_data)
# Return VideoFromFile object
return (VideoFromFile(video_io),)
return comfy_io.NodeOutput(VideoFromFile(video_io))
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
@@ -323,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
"""
@classmethod
def INPUT_TYPES(s):
parent_input = super().INPUT_TYPES()
# Update model options for Veo 3
parent_input["optional"]["model"] = (
IO.COMBO,
{
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
"default": "veo-3.0-generate-001",
"tooltip": "Veo 3 model to use for video generation",
},
def define_schema(cls):
return comfy_io.Schema(
node_id="Veo3VideoGenerationNode",
display_name="Google Veo 3 Video Generation",
category="api node/video/Veo",
description="Generates videos from text prompts using Google's Veo 3 API",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the video",
),
comfy_io.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
default="16:9",
tooltip="Aspect ratio of the output video",
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
optional=True,
),
comfy_io.Int.Input(
"duration_seconds",
default=8,
min=8,
max=8,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
optional=True,
),
comfy_io.Boolean.Input(
"enhance_prompt",
default=True,
tooltip="Whether to enhance the prompt with AI assistance",
optional=True,
),
comfy_io.Combo.Input(
"person_generation",
options=["ALLOW", "BLOCK"],
default="ALLOW",
tooltip="Whether to allow generating people in the video",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFF,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Image.Input(
"image",
tooltip="Optional reference image to guide video generation",
optional=True,
),
comfy_io.Combo.Input(
"model",
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
default="veo-3.0-generate-001",
tooltip="Veo 3 model to use for video generation",
optional=True,
),
comfy_io.Boolean.Input(
"generate_audio",
default=False,
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
# Add generateAudio parameter
parent_input["optional"]["generate_audio"] = (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
}
)
# Update duration constraints for Veo 3 (only 8 seconds supported)
parent_input["optional"]["duration_seconds"] = (
IO.INT,
{
"default": 8,
"min": 8,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
},
)
class VeoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
VeoVideoGenerationNode,
Veo3VideoGenerationNode,
]
return parent_input
# Register the nodes
NODE_CLASS_MAPPINGS = {
"VeoVideoGenerationNode": VeoVideoGenerationNode,
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
}
async def comfy_entrypoint() -> VeoExtension:
return VeoExtension()

View File

@@ -1,49 +1,63 @@
import torch
from typing_extensions import override
import comfy.model_management
import node_helpers
from comfy_api.latest import ComfyExtension, io
class TextEncodeAceStepAudio:
class TextEncodeAceStepAudio(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls):
return io.Schema(
node_id="TextEncodeAceStepAudio",
category="conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("tags", multiline=True, dynamic_prompts=True),
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[io.Conditioning.Output()],
)
CATEGORY = "conditioning"
def encode(self, clip, tags, lyrics, lyrics_strength):
@classmethod
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return (conditioning, )
return io.NodeOutput(conditioning)
class EmptyAceStepLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
class EmptyAceStepLatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyAceStepLatentAudio",
category="latent/audio",
inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
io.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[io.Latent.Output()],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds, batch_size):
def execute(cls, seconds, batch_size) -> io.NodeOutput:
length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
return ({"samples": latent, "type": "audio"}, )
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "type": "audio"})
NODE_CLASS_MAPPINGS = {
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
}
class AceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeAceStepAudio,
EmptyAceStepLatentAudio,
]
async def comfy_entrypoint() -> AceExtension:
return AceExtension()

View File

@@ -1,8 +1,13 @@
import numpy as np
import torch
from tqdm.auto import trange
from typing_extensions import override
import comfy.model_patcher
import comfy.samplers
import comfy.utils
import torch
import numpy as np
from tqdm.auto import trange
from comfy.k_diffusion.sampling import to_d
from comfy_api.latest import ComfyExtension, io
@torch.no_grad()
@@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
return x
class SamplerLCMUpscale:
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
class SamplerLCMUpscale(io.ComfyNode):
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
@classmethod
def INPUT_TYPES(s):
return {"required":
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
"upscale_method": (s.upscale_methods,),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCMUpscale",
category="sampling/custom_sampling/samplers",
inputs=[
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],
outputs=[io.Sampler.Output()],
)
FUNCTION = "get_sampler"
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
@classmethod
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
if scale_steps < 0:
scale_steps = None
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
return (sampler, )
return io.NodeOutput(sampler)
from comfy.k_diffusion.sampling import to_d
import comfy.model_patcher
@torch.no_grad()
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
@@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
return x
class SamplerEulerCFGpp:
class SamplerEulerCFGpp(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"version": (["regular", "alternative"],),}
}
RETURN_TYPES = ("SAMPLER",)
# CATEGORY = "sampling/custom_sampling/samplers"
CATEGORY = "_for_testing"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerEulerCFGpp",
display_name="SamplerEulerCFG++",
category="_for_testing", # "sampling/custom_sampling/samplers"
inputs=[
io.Combo.Input("version", options=["regular", "alternative"]),
],
outputs=[io.Sampler.Output()],
is_experimental=True,
)
FUNCTION = "get_sampler"
def get_sampler(self, version):
@classmethod
def execute(cls, version) -> io.NodeOutput:
if version == "alternative":
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
else:
sampler = comfy.samplers.ksampler("euler_cfg_pp")
return (sampler, )
return io.NodeOutput(sampler)
NODE_CLASS_MAPPINGS = {
"SamplerLCMUpscale": SamplerLCMUpscale,
"SamplerEulerCFGpp": SamplerEulerCFGpp,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SamplerEulerCFGpp": "SamplerEulerCFG++",
}
class AdvancedSamplersExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SamplerLCMUpscale,
SamplerEulerCFGpp,
]
async def comfy_entrypoint() -> AdvancedSamplersExtension:
return AdvancedSamplersExtension()

View File

@@ -1,4 +1,8 @@
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def project(v0, v1):
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
@@ -6,22 +10,45 @@ def project(v0, v1):
v0_orthogonal = v0 - v0_parallel
return v0_parallel, v0_orthogonal
class APG:
class APG(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/custom_sampling"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="APG",
display_name="Adaptive Projected Guidance",
category="sampling/custom_sampling",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"eta",
default=1.0,
min=-10.0,
max=10.0,
step=0.01,
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
),
io.Float.Input(
"norm_threshold",
default=5.0,
min=0.0,
max=50.0,
step=0.1,
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
),
io.Float.Input(
"momentum",
default=0.0,
min=-5.0,
max=1.0,
step=0.01,
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
),
],
outputs=[io.Model.Output()],
)
def patch(self, model, eta, norm_threshold, momentum):
@classmethod
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
running_avg = 0
prev_sigma = None
@@ -65,12 +92,15 @@ class APG:
m = model.clone()
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m,)
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"APG": APG,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"APG": "Adaptive Projected Guidance",
}
class ApgExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
APG,
]
async def comfy_entrypoint() -> ApgExtension:
return ApgExtension()

View File

@@ -1,3 +1,7 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def attention_multiply(attn, model, q, k, v, out):
m = model.clone()
@@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out):
return m
class UNetSelfAttentionMultiply:
class UNetSelfAttentionMultiply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetSelfAttentionMultiply",
category="_for_testing/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
CATEGORY = "_for_testing/attention_experiments"
def patch(self, model, q, k, v, out):
@classmethod
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
m = attention_multiply("attn1", model, q, k, v, out)
return (m, )
return io.NodeOutput(m)
class UNetCrossAttentionMultiply:
class UNetCrossAttentionMultiply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetCrossAttentionMultiply",
category="_for_testing/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
CATEGORY = "_for_testing/attention_experiments"
def patch(self, model, q, k, v, out):
@classmethod
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
m = attention_multiply("attn2", model, q, k, v, out)
return (m, )
return io.NodeOutput(m)
class CLIPAttentionMultiply:
class CLIPAttentionMultiply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "patch"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPAttentionMultiply",
category="_for_testing/attention_experiments",
inputs=[
io.Clip.Input("clip"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[io.Clip.Output()],
is_experimental=True,
)
CATEGORY = "_for_testing/attention_experiments"
def patch(self, clip, q, k, v, out):
@classmethod
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
m = clip.clone()
sd = m.patcher.model_state_dict()
@@ -79,23 +97,28 @@ class CLIPAttentionMultiply:
m.add_patches({key: (None,)}, 0.0, v)
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
m.add_patches({key: (None,)}, 0.0, out)
return (m, )
return io.NodeOutput(m)
class UNetTemporalAttentionMultiply:
class UNetTemporalAttentionMultiply(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetTemporalAttentionMultiply",
category="_for_testing/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
CATEGORY = "_for_testing/attention_experiments"
def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
@classmethod
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
m = model.clone()
sd = model.model_state_dict()
@@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply:
m.add_patches({k: (None,)}, 0.0, cross_temporal)
else:
m.add_patches({k: (None,)}, 0.0, cross_structural)
return (m, )
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
"UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
"CLIPAttentionMultiply": CLIPAttentionMultiply,
"UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
}
class AttentionMultiplyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
UNetSelfAttentionMultiply,
UNetCrossAttentionMultiply,
CLIPAttentionMultiply,
UNetTemporalAttentionMultiply,
]
async def comfy_entrypoint() -> AttentionMultiplyExtension:
return AttentionMultiplyExtension()

View File

@@ -0,0 +1,44 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils
class AudioEncoderLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
}}
RETURN_TYPES = ("AUDIO_ENCODER",)
FUNCTION = "load_model"
CATEGORY = "loaders"
def load_model(self, audio_encoder_name):
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
if audio_encoder is None:
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
return (audio_encoder,)
class AudioEncoderEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, audio_encoder, audio):
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
return (output,)
NODE_CLASS_MAPPINGS = {
"AudioEncoderLoader": AudioEncoderLoader,
"AudioEncoderEncode": AudioEncoderEncode,
}

View File

@@ -0,0 +1,493 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import comfy.model_patcher
if TYPE_CHECKING:
from uuid import UUID
def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"]
sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
return executor(*args, **kwargs)
# prepare next x_prev
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)
if easycache.initial_step:
easycache.first_cond_uuid = uuids[0]
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
easycache.initial_step = False
if has_first_cond_uuid:
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
return easycache.apply_cache_diff(x, uuids)
else:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
if has_first_cond_uuid and easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
output_change_rate = output_change / easycache.output_prev_norm
easycache.output_change_rates.append(output_change_rate.item())
if easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
if easycache.verbose:
logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
if input_change is not None:
easycache.relative_transformation_rate = output_change / input_change
if easycache.verbose:
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev, uuids)
if has_first_cond_uuid:
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return output
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
timestep: float = args[1]
model_options: dict[str] = args[2]
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
# prepare next x_prev
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
if do_easycache:
easycache.check_metadata(x)
if easycache.has_x_prev_subsampled():
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.verbose:
logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
return easycache.apply_cache_diff(x)
else:
if easycache.verbose:
logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
if easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
output_change_rate = output_change / easycache.output_prev_norm
easycache.output_change_rates.append(output_change_rate.item())
if easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
if easycache.verbose:
logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
if input_change is not None:
easycache.relative_transformation_rate = output_change / input_change
if easycache.verbose:
logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev)
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
easycache.output_prev_subsampled = easycache.subsample(output)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return output
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
model_options = args[-1]
easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
easycache.skip_current_step = False
# TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
return executor(*args, **kwargs)
def easycache_sample_wrapper(executor, *args, **kwargs):
"""
This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
"""
try:
guider = executor.class_obj
orig_model_options = guider.model_options
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
# clone and prepare timesteps
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
return executor(*args, **kwargs)
finally:
easycache = guider.model_options['transformer_options']['easycache']
output_change_rates = easycache.output_change_rates
approx_output_change_rates = easycache.approx_output_change_rates
if easycache.verbose:
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
total_steps = len(args[3])-1
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
easycache.reset()
guider.model_options = orig_model_options
class EasyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
self.name = "EasyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.offload_cache_diff = offload_cache_diff
self.verbose = verbose
# timestep values
self.start_t = 0.0
self.end_t = 0.0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.initial_step = True
self.skip_current_step = False
# cache values
self.first_cond_uuid = None
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
# how to deal with mismatched dims
self.allow_mismatch = True
self.cut_from_start = True
self.state_metadata = None
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
def should_do_easycache(self, timestep: float) -> bool:
return (timestep[0] <= self.start_t).item()
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
batch_offset = x.shape[0] // len(uuids)
uuid_idx = uuids.index(self.first_cond_uuid)
if self.subsample_factor > 1:
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
if clone:
return to_return.clone()
return to_return
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
self.total_steps_skipped += 1
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_this_dim = True
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
if skip_this_dim:
skip_this_dim = False
continue
if dim_u != dim_x:
if self.cut_from_start:
slicing.append(slice(dim_x-dim_u, None))
else:
slicing.append(slice(None, dim_u))
else:
slicing.append(slice(None))
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
x = x[slicing]
x += self.uuid_cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if output.shape[1:] != x.shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_dim = True
for dim_o, dim_x in zip(output.shape, x.shape):
if not skip_dim and dim_o != dim_x:
if self.cut_from_start:
slicing.append(slice(dim_x-dim_o, None))
else:
slicing.append(slice(None, dim_o))
else:
slicing.append(slice(None))
skip_dim = False
x = x[slicing]
diff = output - x
batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape[1:])
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.initial_step = True
self.skip_current_step = False
self.output_change_rates = []
self.first_cond_uuid = None
del self.x_prev_subsampled
self.x_prev_subsampled = None
del self.output_prev_subsampled
self.output_prev_subsampled = None
del self.output_prev_norm
self.output_prev_norm = None
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self
def clone(self):
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
class EasyCacheNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EasyCache",
display_name="EasyCache",
description="Native EasyCache implementation.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add EasyCache to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with EasyCache."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
return io.NodeOutput(model)
class LazyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
self.name = "LazyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.offload_cache_diff = offload_cache_diff
self.verbose = verbose
# timestep values
self.start_t = 0.0
self.end_t = 0.0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.initial_step = True
# cache values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.cache_diff: torch.Tensor = None
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
self.state_metadata = None
def has_cache_diff(self) -> bool:
return self.cache_diff is not None
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
def should_do_easycache(self, timestep: float) -> bool:
return (timestep[0] <= self.start_t).item()
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
if self.subsample_factor > 1:
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def apply_cache_diff(self, x: torch.Tensor):
self.total_steps_skipped += 1
return x + self.cache_diff.to(x.device)
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
self.cache_diff = output - x
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape)
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.initial_step = True
self.output_change_rates = []
self.approx_output_change_rates = []
del self.cache_diff
self.cache_diff = None
del self.x_prev_subsampled
self.x_prev_subsampled = None
del self.output_prev_subsampled
self.output_prev_subsampled = None
del self.output_prev_norm
self.output_prev_norm = None
self.total_steps_skipped = 0
self.state_metadata = None
return self
def clone(self):
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
class LazyCacheNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LazyCache",
display_name="LazyCache",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add LazyCache to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with LazyCache."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
return io.NodeOutput(model)
class EasyCacheExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EasyCacheNode,
LazyCacheNode,
]
def comfy_entrypoint():
return EasyCacheExtension()

View File

@@ -0,0 +1,503 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import math
import comfy.model_patcher
if TYPE_CHECKING:
from uuid import UUID
def easysortblock_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
timestep: float = args[1]
model_options: dict[str] = args[2]
easycache: EasySortblockHolder = model_options["transformer_options"]["easycache"]
# initialize predict_ratios
if easycache.initial_step:
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
relevant_sigmas = []
for i,sigma in enumerate(sample_sigmas):
if easycache.check_if_within_timesteps(sigma):
relevant_sigmas.append((i, sigma))
start_index = relevant_sigmas[0][0]
end_index = relevant_sigmas[-1][0]
easycache.predict_ratios = torch.linspace(easycache.start_predict_ratio, easycache.end_predict_ratio, end_index - start_index + 1)
easycache.predict_start_index = start_index
easycache.skip_current_step = False
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
# prepare next x_prev
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
if do_easycache:
easycache.check_metadata(x)
if easycache.has_x_prev_subsampled():
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step
easycache.skip_current_step = True
easycache.steps_skipped.append(easycache.step_count)
else:
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
if easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
output_change_rate = output_change / easycache.output_prev_norm
easycache.output_change_rates.append(output_change_rate.item())
if easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - approx_output_change_rate: {approx_output_change_rate}")
if input_change is not None:
easycache.relative_transformation_rate = output_change / input_change
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - output_change_rate: {output_change_rate}")
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
easycache.output_prev_subsampled = easycache.subsample(output)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasySortblock [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
# increment step count
easycache.step_count += 1
easycache.initial_step = False
return output
def easysortblock_outer_sample_wrapper(executor, *args, **kwargs):
"""
This OUTER_SAMPLE wrapper makes sure EasySortblock is prepped for current run, and all memory usage is cleared at the end.
"""
try:
guider = executor.class_obj
orig_model_options = guider.model_options
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
# clone and prepare timesteps
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
easycache: EasySortblockHolder = guider.model_options['transformer_options']['easycache']
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
return executor(*args, **kwargs)
finally:
easycache = guider.model_options['transformer_options']['easycache']
output_change_rates = easycache.output_change_rates
approx_output_change_rates = easycache.approx_output_change_rates
if easycache.verbose:
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
total_steps = len(args[3])-1
logging.info(f"{easycache.name} - skipped {len(easycache.steps_skipped)}/{total_steps} steps")# ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
logging.info(f"{easycache.name} - skipped steps: {easycache.steps_skipped}")
easycache.reset()
guider.model_options = orig_model_options
def model_forward_wrapper(executor, *args, **kwargs):
# TODO: make work with batches of conds
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
sigmas = transformer_options["sigmas"]
sb_holder: EasySortblockHolder = transformer_options["easycache"]
# if initial step, prepare everything for Sortblock
if sb_holder.initial_step:
logging.info(f"EasySortblock: inside model {executor.class_obj.__class__.__name__}")
# TODO: generalize for other models
# these won't stick around past this step; should store on sb_holder instead
logging.info(f"EasySortblock: preparing {len(executor.class_obj.double_blocks)} double blocks and {len(executor.class_obj.single_blocks)} single blocks")
if hasattr(executor.class_obj, "double_blocks"):
for block in executor.class_obj.double_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "single_blocks"):
for block in executor.class_obj.single_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "blocks"):
for block in executor.class_obj.block:
prepare_block(block, sb_holder)
if sb_holder.skip_current_step:
predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
predict_ratio = sb_holder.predict_ratios[predict_index]
logging.info(f"EasySortblock: skipping step {sb_holder.step_count}, predict_ratio: {predict_ratio}")
# reuse_ratio = 1.0 - predict_ratio
for block_type, blocks in sb_holder.blocks_per_type.items():
for block in blocks:
cache: BlockCache = block.__block_cache
cache.allowed_to_skip = False
sorted_blocks = sorted(blocks, key=lambda x: (x.__block_cache.consecutive_skipped_steps, x.__block_cache.prev_change_rate))
# for block in sorted_blocks:
# pass
threshold_index = int(len(sorted_blocks) * predict_ratio)
# blocks with lower similarity are marked for recomputation
for block in sorted_blocks[:threshold_index]:
cache: BlockCache = block.__block_cache
cache.allowed_to_skip = True
logging.info(f"EasySortblock: skip block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
not_skipped = [block for block in blocks if not block.__block_cache.allowed_to_skip]
for block in not_skipped:
logging.info(f"EasySortblock: reco block {block.__class__.__name__} - consecutive_skipped_steps: {block.__block_cache.consecutive_skipped_steps}, prev_change_rate: {block.__block_cache.prev_change_rate}, index: {block.__block_cache.block_index}")
logging.info(f"EasySortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for prediction and {len(sorted_blocks[threshold_index:])} blocks for recomputation")
# return executor(*args, **kwargs)
to_return = executor(*args, **kwargs)
return to_return
def block_forward_factory(func, block):
def block_forward_wrapper(*args, **kwargs):
transformer_options: dict[str] = kwargs.get("transformer_options")
sigmas = transformer_options["sigmas"]
sb_holder: EasySortblockHolder = transformer_options["easycache"]
cache: BlockCache = block.__block_cache
# make sure stream count is properly set for this block
if sb_holder.initial_step:
sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
cache.block_index = transformer_options['block'][1]
cache.stream_count = transformer_options['block'][2]
if sb_holder.is_past_end_timestep(sigmas):
return func(*args, **kwargs)
# do sortblock stuff
x = cache.get_next_x_prev(args, kwargs)
# prepare next_x_prev
next_x_prev = cache.get_next_x_prev(args, kwargs, clone=True)
input_change = None
do_sortblock = sb_holder.should_do_easycache(sigmas)
if do_sortblock:
# TODO: checkmetadata
if cache.has_x_prev_subsampled():
input_change = (cache.subsample(x, clone=False) - cache.x_prev_subsampled).flatten().abs().mean()
if cache.has_output_prev_norm() and cache.has_relative_transformation_rate():
approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
cache.cumulative_change_rate += approx_output_change_rate
if cache.allowed_to_skip:
# if cache.cumulative_change_rate < sb_holder.reuse_threshold:
# accumulate error + skip block
# cache.want_to_skip = True
# if cache.allowed_to_skip:
cache.consecutive_skipped_steps += 1
cache.prev_change_rate = approx_output_change_rate
return cache.apply_cache_diff(x, sb_holder)
else:
# reset error; NOT skipping block and recalculating
cache.cumulative_change_rate = 0.0
cache.prev_change_rate = approx_output_change_rate
cache.want_to_skip = False
cache.consecutive_skipped_steps = 0
# output_raw is expected to have cache.stream_count elements if count is greaater than 1 (double block, etc.)
output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
# if more than one stream from block, only use first one
if isinstance(output_raw, tuple):
output = output_raw[0]
else:
output = output_raw
if cache.has_output_prev_norm():
output_change = (cache.subsample(output, clone=False) - cache.output_prev_subsampled).flatten().abs().mean()
# if verbose in future
output_change_rate = output_change / cache.output_prev_norm
cache.output_change_rates.append(output_change_rate.item())
if cache.has_relative_transformation_rate():
approx_output_change_rate = (cache.relative_transformation_rate * input_change) / cache.output_prev_norm
cache.approx_output_change_rates.append(approx_output_change_rate.item())
if input_change is not None:
cache.relative_transformation_rate = output_change / input_change
# TODO: allow cache_diff to be offloaded
cache.update_cache_diff(output_raw, next_x_prev)
cache.x_prev_subsampled = cache.subsample(next_x_prev)
cache.output_prev_subsampled = cache.subsample(output)
cache.output_prev_norm = output.flatten().abs().mean()
return output_raw
return block_forward_wrapper
def prepare_block(block, sb_holder: EasySortblockHolder, stream_count: int=1):
sb_holder.add_to_all_blocks(block)
block.__original_forward = block.forward
block.forward = block_forward_factory(block.__original_forward, block)
block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)
def clean_block(block):
block.forward = block.__original_forward
del block.__original_forward
del block.__block_cache
class BlockCache:
def __init__(self, subsample_factor: int=8, verbose: bool=False):
self.subsample_factor = subsample_factor
self.verbose = verbose
self.stream_count = 1
self.block_index = 0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.prev_change_rate = 0.0
# cached values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.cache_diff: list[torch.Tensor] = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.steps_skipped: list[int] = []
self.consecutive_skipped_steps = 0
# self.state_metadata = None
self.want_to_skip = False
self.allowed_to_skip = False
def has_cache_diff(self) -> bool:
return self.cache_diff[0] is not None
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def get_next_x_prev(self, d_args: tuple[torch.Tensor, ...], d_kwargs: dict[str, torch.Tensor], clone: bool=False) -> tuple[torch.Tensor, ...]:
if self.stream_count == 1:
if clone:
return d_args[0].clone()
return d_args[0]
keys = list(d_kwargs.keys())[:self.stream_count]
orig_inputs = []
for key in keys:
if clone:
orig_inputs.append(d_kwargs[key].clone())
else:
orig_inputs.append(d_kwargs[key])
return tuple(orig_inputs)
def subsample(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], clone: bool = True) -> torch.Tensor:
# subsample only the first compoenent
if isinstance(x, tuple):
return self.subsample(x[0], clone)
if self.subsample_factor > 1:
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def apply_cache_diff(self, x: Union[torch.Tensor, tuple[torch.Tensor, ...]], sb_holder: EasySortblockHolder):
self.steps_skipped.append(sb_holder.step_count)
if not isinstance(x, tuple):
x = (x, )
to_return = tuple([x[i] + self.cache_diff[i] for i in range(self.stream_count)])
if len(to_return) == 1:
return to_return[0]
return to_return
def update_cache_diff(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], x: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
if not isinstance(output_raw, tuple):
output_raw = (output_raw, )
if not isinstance(x, tuple):
x = (x, )
self.cache_diff = tuple([output_raw[i] - x[i] for i in range(self.stream_count)])
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.prev_change_rate = 0.0
self.x_prev_subsampled = None
self.output_prev_subsampled = None
self.output_prev_norm = None
self.cache_diff = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.steps_skipped = []
self.consecutive_skipped_steps = 0
self.want_to_skip = False
self.allowed_to_skip = False
return self
class EasySortblockHolder:
def __init__(self, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, max_skipped_steps: int,
start_percent: float, end_percent: float, subsample_factor: int, verbose: bool=False):
self.name = "EasySortblock"
self.reuse_threshold = reuse_threshold
self.start_predict_ratio = start_predict_ratio
self.end_predict_ratio = end_predict_ratio
self.max_skipped_steps = max_skipped_steps
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.verbose = verbose
# timestep values
self.start_t = 0.0
self.end_t = 0.0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.initial_step = True
self.step_count = 0
self.predict_ratios = []
self.skip_current_step = False
self.predict_start_index = 0
# cache values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.steps_skipped: list[int] = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.state_metadata = None
self.all_blocks = []
self.blocks_per_type = {}
def add_to_all_blocks(self, block):
self.all_blocks.append(block)
def add_to_blocks_per_type(self, block, block_type: str):
self.blocks_per_type.setdefault(block_type, []).append(block)
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
def should_do_easycache(self, timestep: float) -> bool:
return (timestep[0] <= self.start_t).item()
def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
return (timestep <= self.start_t).item() and (timestep > self.end_t).item()
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
if self.subsample_factor > 1:
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape)
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
logging.info(f"EasySortblock: resetting {len(self.all_blocks)} blocks")
for block in self.all_blocks:
clean_block(block)
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.initial_step = True
self.step_count = 0
self.predict_ratios = []
self.skip_current_step = False
self.predict_start_index = 0
self.x_prev_subsampled = None
self.output_prev_subsampled = None
self.output_prev_norm = None
self.steps_skipped = []
self.output_change_rates = []
self.approx_output_change_rates = []
self.state_metadata = None
self.all_blocks = []
self.blocks_per_type = {}
return self
def clone(self):
return EasySortblockHolder(self.reuse_threshold, self.start_predict_ratio, self.end_predict_ratio, self.max_skipped_steps,
self.start_percent, self.end_percent, self.subsample_factor, self.verbose)
class EasySortblockScaledNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EasySortblockScaled",
display_name="EasySortblockScaled",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add Sortblock to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with Sortblock."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
# TODO: check for specific flavors of supported models
model = model.clone()
model.model_options["transformer_options"]["easycache"] = EasySortblockHolder(reuse_threshold, start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", easysortblock_predict_noise_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", easysortblock_outer_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
return io.NodeOutput(model)
class EasySortblockExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
# EasySortblockNode,
EasySortblockScaledNode,
]
def comfy_entrypoint():
return EasySortblockExtension()

View File

@@ -1,6 +1,7 @@
import comfy.utils
import comfy_extras.nodes_post_processing
import torch
import nodes
def reshape_latent_to(target_shape, latent, repeat_batch=True):
@@ -105,6 +106,73 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
class LatentConcat:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, dim):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])
if "-" in dim:
c = (s2, s1)
else:
c = (s1, s2)
if "x" in dim:
dim = -1
elif "y" in dim:
dim = -2
elif "t" in dim:
dim = -3
samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,)
class LatentCut:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT",),
"dim": (["x", "y", "t"], ),
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples, dim, index, amount):
samples_out = samples.copy()
s1 = samples["samples"]
if "x" in dim:
dim = s1.ndim - 1
elif "y" in dim:
dim = s1.ndim - 2
elif "t" in dim:
dim = s1.ndim - 3
if index >= 0:
index = min(index, s1.shape[dim] - 1)
amount = min(s1.shape[dim] - index, amount)
else:
index = max(index, -s1.shape[dim])
amount = min(-index, amount)
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
@@ -279,6 +347,8 @@ NODE_CLASS_MAPPINGS = {
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentConcat": LatentConcat,
"LatentCut": LatentCut,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentApplyOperation": LatentApplyOperation,

View File

@@ -166,7 +166,7 @@ class LTXVAddGuide:
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,

View File

@@ -0,0 +1,163 @@
import torch
import folder_paths
import comfy.utils
import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
super().__init__()
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
self.input_proj = operations.Linear(dim, dim)
self.act = torch.nn.GELU()
self.output_proj = operations.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
device=None, dtype=None, operations=None
):
super().__init__()
self.additional_in_dim = additional_in_dim
self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
self.controlnet_blocks = torch.nn.ModuleList(
[
BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
]
)
def process_input_latent_image(self, latent_image):
latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16])
patch_size = 2
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
return self.img_in(hidden_states)
def control_block(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
}}
RETURN_TYPES = ("MODEL_PATCH",)
FUNCTION = "load_model_patch"
EXPERIMENTAL = True
CATEGORY = "advanced/loaders"
def load_model_patch(self, name):
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
class DiffSynthCnetPatch:
def __init__(self, model_patch, vae, image, strength, mask=None):
self.model_patch = model_patch
self.vae = vae
self.image = image
self.strength = strength
self.mask = mask
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
self.encoded_image_size = (image.shape[1], image.shape[2])
def encode_latent_cond(self, image):
latent_image = self.vae.encode(image)
if self.model_patch.model.additional_in_dim > 0:
if self.mask is None:
mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4]
else:
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
return torch.cat([latent_image, mask_], dim=1)
else:
return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
block_index = kwargs.get("block_index")
spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
comfy.model_management.load_models_gpu(loaded_models)
img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength)
kwargs['img'] = img
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"image": ("IMAGE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"mask": ("MASK",)}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "diffsynth_controlnet"
EXPERIMENTAL = True
CATEGORY = "advanced/loaders/qwen"
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
model_patched = model.clone()
image = image[:, :, :, :3]
if mask is not None:
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if mask.ndim == 4:
mask = mask.unsqueeze(2)
mask = 1.0 - mask
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
}

View File

@@ -0,0 +1,462 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import math
import comfy.model_patcher
if TYPE_CHECKING:
from uuid import UUID
def prepare_noise_wrapper(executor, *args, **kwargs):
try:
transformer_options: dict[str] = args[2]["transformer_options"]
sb_holder: SortblockHolder = transformer_options["sortblock"]
if sb_holder.initial_step:
sample_sigmas = transformer_options["sample_sigmas"]
relevant_sigmas = []
# find start and end steps, then use to interpolate between start and end predict ratios
for i,sigma in enumerate(sample_sigmas):
if sb_holder.check_if_within_timesteps(sigma):
relevant_sigmas.append((i, sigma))
start_index = relevant_sigmas[0][0]
end_index = relevant_sigmas[-1][0]
sb_holder.predict_ratios = torch.linspace(sb_holder.start_predict_ratio, sb_holder.end_predict_ratio, end_index - start_index + 1)
sb_holder.predict_start_index = start_index
return executor(*args, **kwargs)
finally:
transformer_options: dict[str] = args[2]["transformer_options"]
sb_holder: SortblockHolder = transformer_options["sortblock"]
sb_holder.step_count += 1
if sb_holder.should_do_sortblock():
sb_holder.active_steps += 1
def outer_sample_wrapper(executor, *args, **kwargs):
try:
logging.info("Sortblock: inside outer_sample!")
guider = executor.class_obj
orig_model_options = guider.model_options
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
# clone and prepare timesteps
sb_holder = guider.model_options["transformer_options"]["sortblock"]
guider.model_options["transformer_options"]["sortblock"] = sb_holder.clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
sb_holder: SortblockHolder = guider.model_options["transformer_options"]["sortblock"]
logging.info(f"Sortblock: enabled - threshold: {sb_holder.start_predict_ratio}, start_percent: {sb_holder.start_percent}, end_percent: {sb_holder.end_percent}")
return executor(*args, **kwargs)
finally:
sb_holder = guider.model_options["transformer_options"]["sortblock"]
logging.info(f"Sortblock: final step count: {sb_holder.step_count}")
sb_holder.reset()
guider.model_options = orig_model_options
def model_forward_wrapper(executor, *args, **kwargs):
# TODO: make work with batches of conds
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
sigmas = transformer_options["sigmas"]
sb_holder: SortblockHolder = transformer_options["sortblock"]
sb_holder.update_should_do_sortblock(sigmas)
# if initial step, prepare everything for Sortblock
if sb_holder.initial_step:
logging.info(f"Sortblock: inside model {executor.class_obj.__class__.__name__}")
# TODO: generalize for other models
# these won't stick around past this step; should store on sb_holder instead
logging.info(f"Sortblock: preparing {len(executor.class_obj.double_blocks)} double blocks and {len(executor.class_obj.single_blocks)} single blocks")
if hasattr(executor.class_obj, "double_blocks"):
for block in executor.class_obj.double_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "single_blocks"):
for block in executor.class_obj.single_blocks:
prepare_block(block, sb_holder)
if hasattr(executor.class_obj, "blocks"):
for block in executor.class_obj.block:
prepare_block(block, sb_holder)
# when 0: Initialization(1)
if sb_holder.step_modulus == 0:
logging.info(f"Sortblock: for step {sb_holder.step_count}, all blocks are marked for recomputation")
# all features are computed, input-outputs changes for all DiT blocks are stored for relative step 'k'
sb_holder.activated_steps.append(sb_holder.step_count)
for block in sb_holder.all_blocks:
cache: BlockCache = block.__block_cache
cache.mark_recompute()
# all block operations are performed in forward pass of model
to_return = executor(*args, **kwargs)
# when 1: Select DiT blocks(4)
if sb_holder.step_modulus == 1:
predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
predict_ratio = sb_holder.predict_ratios[predict_index]
logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}")
reuse_ratio = 1.0 - predict_ratio
for block_type, blocks in sb_holder.blocks_per_type.items():
sorted_blocks = sorted(blocks, key=lambda x: x.__block_cache.cosine_similarity)
threshold_index = int(len(sorted_blocks) * reuse_ratio)
# blocks with lower similarity are marked for recomputation
for block in sorted_blocks[:threshold_index]:
cache: BlockCache = block.__block_cache
cache.mark_recompute()
# blocks with higher similarity are marked for prediction
for block in sorted_blocks[threshold_index:]:
cache: BlockCache = block.__block_cache
cache.mark_predict()
logging.info(f"Sortblock: for {block_type}, selected {len(sorted_blocks[:threshold_index])} blocks for recomputation and {len(sorted_blocks[threshold_index:])} blocks for prediction")
if sb_holder.initial_step:
sb_holder.initial_step = False
return to_return
def block_forward_factory(func, block):
def block_forward_wrapper(*args, **kwargs):
transformer_options: dict[str] = kwargs.get("transformer_options")
sb_holder: SortblockHolder = transformer_options["sortblock"]
cache: BlockCache = block.__block_cache
# make sure stream count is properly set for this block
if sb_holder.initial_step:
sb_holder.add_to_blocks_per_type(block, transformer_options['block'][0])
cache.block_index = transformer_options['block'][1]
cache.stream_count = transformer_options['block'][2]
# do sortblock stuff
if cache.recompute and sb_holder.step_modulus != 1:
# clone relevant inputs
orig_inputs = cache.get_orig_inputs(args, kwargs, clone=True)
# get block outputs
# NOTE: output_raw is expected to have cache.stream_count elements if count is greaater than 1 (double block, etc.)
if cache.stream_count == 1:
zzz = 10
output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]] = func(*args, **kwargs)
# perform derivative approximation;
cache.derivative_approximation(sb_holder, output_raw, orig_inputs)
# if step_modulus is 0, input-output changes for DiT block are stored
if sb_holder.step_modulus == 0:
cache.cache_previous_residual(output_raw, orig_inputs)
else:
# if not to recompute, predict features for current timestep
orig_inputs = cache.get_orig_inputs(args, kwargs, clone=False)
# when 1: Linear Prediction(2)
# if step_modulus is 1, store block residuals as 'current' after applying taylor_formula
if sb_holder.step_modulus == 1:
cache.cache_current_residual(sb_holder)
# based on features computed in last timestep, all features for current timestep are predicted using Eq. 4,
# input-output changes for all DiT blocks are stored for relative step 'k+1'
output_raw = cache.apply_linear_prediction(sb_holder, orig_inputs)
# when 1: Identify Changes(3)
if sb_holder.step_modulus == 1:
# based on features computed in last timestep, all features for current timestep are predicted using Eq. 4,
# input-output changes for all DiT blocks are stored for relative step 'k+1'
cache.calculate_cosine_similarity()
# return output_raw
return output_raw
return block_forward_wrapper
def perform_sortblock(blocks: list):
...
def prepare_block(block, sb_holder: SortblockHolder, stream_count: int=1):
sb_holder.add_to_all_blocks(block)
block.__original_forward = block.forward
block.forward = block_forward_factory(block.__original_forward, block)
block.__block_cache = BlockCache(subsample_factor=sb_holder.subsample_factor, verbose=sb_holder.verbose)
def clean_block(block):
block.forward = block.__original_forward
del block.__original_forward
del block.__block_cache
def subsample(x: torch.Tensor, factor: int, clone: bool=True) -> torch.Tensor:
if factor > 1:
to_return = x[..., ::factor, ::factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
class BlockCache:
def __init__(self, subsample_factor: int=8, verbose: bool=False):
self.subsample_factor = subsample_factor
self.verbose = verbose
self.stream_count = 1
self.recompute = False
self.block_index = 0
# cached values
self.previous_residual_subsampled: torch.Tensor = None
self.current_residual_subsampled: torch.Tensor = None
self.cosine_similarity: float = None
self.previous_taylor_factors: dict[int, torch.Tensor] = {}
self.current_taylor_factors: dict[int, torch.Tensor] = {}
def mark_recompute(self):
self.recompute = True
def mark_predict(self):
self.recompute = False
def cache_previous_residual(self, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
if isinstance(output_raw, tuple):
output_raw = output_raw[0]
if isinstance(orig_inputs, tuple):
orig_inputs = orig_inputs[0]
del self.previous_residual_subsampled
self.previous_residual_subsampled = subsample(output_raw - orig_inputs, self.subsample_factor, clone=True)
def cache_current_residual(self, sb_holder: SortblockHolder):
del self.current_residual_subsampled
self.current_residual_subsampled = subsample(self.use_taylor_formula(sb_holder)[0], self.subsample_factor, clone=True)
def get_orig_inputs(self, d_args: tuple, d_kwargs: dict, clone: bool=True) -> tuple[torch.Tensor, ...]:
if self.stream_count == 1:
if clone:
return d_args[0].clone()
return d_args[0]
keys = list(d_kwargs.keys())[:self.stream_count]
orig_inputs = []
for key in keys:
if clone:
orig_inputs.append(d_kwargs[key].clone())
else:
orig_inputs.append(d_kwargs[key])
return tuple(orig_inputs)
def apply_linear_prediction(self, sb_holder: SortblockHolder, orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
drop_tuple = False
if not isinstance(orig_inputs, tuple):
orig_inputs = (orig_inputs,)
drop_tuple = True
taylor_results = self.use_taylor_formula(sb_holder)
for output, taylor_result in zip(orig_inputs, taylor_results):
if output.shape != taylor_result.shape:
zzz = 10
output += taylor_result
if drop_tuple:
orig_inputs = orig_inputs[0]
return orig_inputs
def calculate_cosine_similarity(self) -> None:
self.cosine_similarity = torch.nn.functional.cosine_similarity(self.previous_residual_subsampled, self.current_residual_subsampled, dim=-1).mean().item()
def derivative_approximation(self, sb_holder: SortblockHolder, output_raw: Union[torch.Tensor, tuple[torch.Tensor, ...]], orig_inputs: Union[torch.Tensor, tuple[torch.Tensor, ...]]):
activation_distance = sb_holder.activated_steps[-1] - sb_holder.activated_steps[-2]
# make tuple if not already tuple, so that works with both single and double blocks
if not isinstance(output_raw, tuple):
output_raw = (output_raw,)
if not isinstance(orig_inputs, tuple):
orig_inputs = (orig_inputs,)
for i, (output, x) in enumerate(zip(output_raw, orig_inputs)):
feature = output.clone() - x
has_previous_taylor_factor = self.previous_taylor_factors.get(i, None) is not None
# NOTE: not sure why - 2, but that's what's in the original implementation. Maybe consider changing values?
if has_previous_taylor_factor and sb_holder.step_count > (sb_holder.first_enhance - 2):
self.current_taylor_factors[i] = (
feature - self.previous_taylor_factors[i]
) / activation_distance
self.previous_taylor_factors[i] = feature
def use_taylor_formula(self, sb_holder: SortblockHolder) -> tuple[torch.Tensor, ...]:
step_distance = sb_holder.step_count - sb_holder.activated_steps[-1]
output_predicted = []
for key in self.previous_taylor_factors.keys():
previous_tf = self.previous_taylor_factors[key]
current_tf = self.current_taylor_factors[key]
predicted = taylor_formula(previous_tf, 0, step_distance)
predicted += taylor_formula(current_tf, 1, step_distance)
output_predicted.append(predicted)
return tuple(output_predicted)
def reset(self):
self.recompute = False
self.current_residual_subsampled = None
self.previous_residual_subsampled = None
self.cosine_similarity = None
self.previous_taylor_factors = {}
self.current_taylor_factors = {}
def taylor_formula(taylor_factor: torch.Tensor, i: int, step_distance: int):
return (
(1 / math.factorial(i))
* taylor_factor
* (step_distance ** i)
)
class SortblockHolder:
def __init__(self, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int,
start_percent: float, end_percent: float, subsample_factor: int=8, verbose: bool=False):
self.start_predict_ratio = start_predict_ratio
self.end_predict_ratio = end_predict_ratio
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.verbose = verbose
# NOTE: number represents steps
self.policy_refresh_interval = policy_refresh_interval
self.active_policy_refresh_interval = 1
self.first_enhance = 3 # NOTE: this value is 2 higher than the one actually used in code (subtracted by 2 in derivative_approximation)
# timestep values
self.start_t = 0.0
self.end_t = 0.0
self.curr_t = 0.0
# control values
self.initial_step = True
self.step_count = 0
self.activated_steps: list[int] = [0]
self.step_modulus = 0
self.do_sortblock = False
self.active_steps = 0
self.predict_ratios = []
self.predict_start_index = 0
# cache values
self.all_blocks = []
self.blocks_per_type = {}
def add_to_all_blocks(self, block):
self.all_blocks.append(block)
def add_to_blocks_per_type(self, block, block_type: str):
self.blocks_per_type.setdefault(block_type, []).append(block)
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def check_if_within_timesteps(self, timestep: Union[float, torch.Tensor]) -> bool:
return (timestep <= self.start_t).item() and (timestep > self.end_t).item()
def update_should_do_sortblock(self, timestep: float) -> bool:
self.do_sortblock = (timestep[0] <= self.start_t).item() and (timestep[0] > self.end_t).item()
self.curr_t = timestep
if self.do_sortblock:
self.active_policy_refresh_interval = self.policy_refresh_interval
else:
self.active_policy_refresh_interval = 1
self.update_step_modulus()
return self.do_sortblock
def update_step_modulus(self):
self.step_modulus = int(self.step_count % self.active_policy_refresh_interval)
def should_do_sortblock(self) -> bool:
return self.do_sortblock
def reset(self):
self.initial_step = True
self.curr_t = 0.0
logging.info(f"Sortblock: resetting {len(self.all_blocks)} blocks")
for block in self.all_blocks:
clean_block(block)
self.all_blocks = []
self.blocks_per_type = {}
self.step_count = 0
self.activated_steps = [0]
self.step_modulus = 0
self.active_steps = 0
self.predict_ratios = []
self.do_sortblock = False
self.predict_start_index = 0
return self
def clone(self):
return SortblockHolder(start_predict_ratio=self.start_predict_ratio, end_predict_ratio=self.end_predict_ratio, policy_refresh_interval=self.policy_refresh_interval,
start_percent=self.start_percent, end_percent=self.end_percent, subsample_factor=self.subsample_factor,
verbose=self.verbose)
class SortblockNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="Sortblock",
display_name="Sortblock",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add Sortblock to."),
io.Float.Input("predict_ratio", min=0.0, default=0.8, max=3.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with Sortblock."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
# TODO: check for specific flavors of supported models
model = model.clone()
model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio=predict_ratio, end_predict_ratio=predict_ratio, policy_refresh_interval=policy_refresh_interval,
start_percent=start_percent, end_percent=end_percent, subsample_factor=8, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
return io.NodeOutput(model)
class SortblockScaledNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SortblockScaled",
display_name="SortblockScaled",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add Sortblock to."),
io.Float.Input("start_predict_ratio", min=0.0, default=0.2, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Float.Input("end_predict_ratio", min=0.0, default=0.9, max=1.0, step=0.01, tooltip="The ratio of blocks to predict."),
io.Int.Input("policy_refresh_interval", min=3, default=5, max=100, step=1, tooltip="The interval at which to refresh the policy."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of Sortblock."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of Sortblock."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with Sortblock."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, start_predict_ratio: float, end_predict_ratio: float, policy_refresh_interval: int, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
# TODO: check for specific flavors of supported models
model = model.clone()
model.model_options["transformer_options"]["sortblock"] = SortblockHolder(start_predict_ratio, end_predict_ratio, policy_refresh_interval, start_percent, end_percent, subsample_factor=8, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "sortblock", prepare_noise_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sortblock", outer_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "sortblock", model_forward_wrapper)
return io.NodeOutput(model)
class SortblockExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SortblockNode,
SortblockScaledNode,
]
def comfy_entrypoint():
return SortblockExtension()

View File

@@ -1,77 +1,91 @@
import re
from typing_extensions import override
from comfy.comfy_types.node_typing import IO
from comfy_api.latest import ComfyExtension, io
class StringConcatenate():
class StringConcatenate(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True}),
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
}
}
def define_schema(cls):
return io.Schema(
node_id="StringConcatenate",
display_name="Concatenate",
category="utils/string",
inputs=[
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),
io.String.Input("delimiter", multiline=False, default=""),
],
outputs=[
io.String.Output(),
]
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, delimiter, **kwargs):
return delimiter.join((string_a, string_b)),
class StringSubstring():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"start": (IO.INT, {}),
"end": (IO.INT, {}),
}
}
def execute(cls, string_a, string_b, delimiter):
return io.NodeOutput(delimiter.join((string_a, string_b)))
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, start, end, **kwargs):
return string[start:end],
class StringLength():
class StringSubstring(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True})
}
}
def define_schema(cls):
return io.Schema(
node_id="StringSubstring",
display_name="Substring",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.Int.Input("start"),
io.Int.Input("end"),
],
outputs=[
io.String.Output(),
]
)
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("length",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, **kwargs):
length = len(string)
return length,
class CaseConverter():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
}
}
def execute(cls, string, start, end):
return io.NodeOutput(string[start:end])
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
class StringLength(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringLength",
display_name="Length",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
],
outputs=[
io.Int.Output(display_name="length"),
]
)
@classmethod
def execute(cls, string):
return io.NodeOutput(len(string))
class CaseConverter(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CaseConverter",
display_name="Case Converter",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]),
],
outputs=[
io.String.Output(),
]
)
@classmethod
def execute(cls, string, mode):
if mode == "UPPERCASE":
result = string.upper()
elif mode == "lowercase":
@@ -83,24 +97,27 @@ class CaseConverter():
else:
result = string
return result,
return io.NodeOutput(result)
class StringTrim():
class StringTrim(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
}
}
def define_schema(cls):
return io.Schema(
node_id="StringTrim",
display_name="Trim",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.Combo.Input("mode", options=["Both", "Left", "Right"]),
],
outputs=[
io.String.Output(),
]
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
@classmethod
def execute(cls, string, mode):
if mode == "Both":
result = string.strip()
elif mode == "Left":
@@ -110,70 +127,78 @@ class StringTrim():
else:
result = string
return result,
return io.NodeOutput(result)
class StringReplace():
class StringReplace(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"find": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True})
}
}
def define_schema(cls):
return io.Schema(
node_id="StringReplace",
display_name="Replace",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("find", multiline=True),
io.String.Input("replace", multiline=True),
],
outputs=[
io.String.Output(),
]
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, find, replace, **kwargs):
result = string.replace(find, replace)
return result,
class StringContains():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"substring": (IO.STRING, {"multiline": True}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
def execute(cls, string, find, replace):
return io.NodeOutput(string.replace(find, replace))
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("contains",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, substring, case_sensitive, **kwargs):
class StringContains(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StringContains",
display_name="Contains",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("substring", multiline=True),
io.Boolean.Input("case_sensitive", default=True),
],
outputs=[
io.Boolean.Output(display_name="contains"),
]
)
@classmethod
def execute(cls, string, substring, case_sensitive):
if case_sensitive:
contains = substring in string
else:
contains = substring.lower() in string.lower()
return contains,
return io.NodeOutput(contains)
class StringCompare():
class StringCompare(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
def define_schema(cls):
return io.Schema(
node_id="StringCompare",
display_name="Compare",
category="utils/string",
inputs=[
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),
io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]),
io.Boolean.Input("case_sensitive", default=True),
],
outputs=[
io.Boolean.Output(),
]
)
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
@classmethod
def execute(cls, string_a, string_b, mode, case_sensitive):
if case_sensitive:
a = string_a
b = string_b
@@ -182,31 +207,34 @@ class StringCompare():
b = string_b.lower()
if mode == "Equal":
return a == b,
return io.NodeOutput(a == b)
elif mode == "Starts With":
return a.startswith(b),
return io.NodeOutput(a.startswith(b))
elif mode == "Ends With":
return a.endswith(b),
return io.NodeOutput(a.endswith(b))
class RegexMatch():
class RegexMatch(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False})
}
}
def define_schema(cls):
return io.Schema(
node_id="RegexMatch",
display_name="Regex Match",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.Boolean.Input("case_insensitive", default=True),
io.Boolean.Input("multiline", default=False),
io.Boolean.Input("dotall", default=False),
],
outputs=[
io.Boolean.Output(display_name="matches"),
]
)
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("matches",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
@classmethod
def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall):
flags = 0
if case_insensitive:
@@ -223,29 +251,32 @@ class RegexMatch():
except re.error:
result = False
return result,
return io.NodeOutput(result)
class RegexExtract():
class RegexExtract(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False}),
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
}
}
def define_schema(cls):
return io.Schema(
node_id="RegexExtract",
display_name="Regex Extract",
category="utils/string",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]),
io.Boolean.Input("case_insensitive", default=True),
io.Boolean.Input("multiline", default=False),
io.Boolean.Input("dotall", default=False),
io.Int.Input("group_index", default=1, min=0, max=100),
],
outputs=[
io.String.Output(),
]
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
@classmethod
def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index):
join_delimiter = "\n"
flags = 0
@@ -294,32 +325,33 @@ class RegexExtract():
except re.error:
result = ""
return result,
return io.NodeOutput(result)
class RegexReplace():
DESCRIPTION = "Find and replace text using regex patterns."
class RegexReplace(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True}),
},
"optional": {
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
}
}
def define_schema(cls):
return io.Schema(
node_id="RegexReplace",
display_name="Regex Replace",
category="utils/string",
description="Find and replace text using regex patterns.",
inputs=[
io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True),
io.String.Input("replace", multiline=True),
io.Boolean.Input("case_insensitive", default=True, optional=True),
io.Boolean.Input("multiline", default=False, optional=True),
io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."),
io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."),
],
outputs=[
io.String.Output(),
]
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
@classmethod
def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0):
flags = 0
if case_insensitive:
@@ -329,32 +361,25 @@ class RegexReplace():
if dotall:
flags |= re.DOTALL
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
return result,
return io.NodeOutput(result)
NODE_CLASS_MAPPINGS = {
"StringConcatenate": StringConcatenate,
"StringSubstring": StringSubstring,
"StringLength": StringLength,
"CaseConverter": CaseConverter,
"StringTrim": StringTrim,
"StringReplace": StringReplace,
"StringContains": StringContains,
"StringCompare": StringCompare,
"RegexMatch": RegexMatch,
"RegexExtract": RegexExtract,
"RegexReplace": RegexReplace,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"StringConcatenate": "Concatenate",
"StringSubstring": "Substring",
"StringLength": "Length",
"CaseConverter": "Case Converter",
"StringTrim": "Trim",
"StringReplace": "Replace",
"StringContains": "Contains",
"StringCompare": "Compare",
"RegexMatch": "Regex Match",
"RegexExtract": "Regex Extract",
"RegexReplace": "Regex Replace",
}
class StringExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
StringConcatenate,
StringSubstring,
StringLength,
CaseConverter,
StringTrim,
StringReplace,
StringContains,
StringCompare,
RegexMatch,
RegexExtract,
RegexReplace,
]
async def comfy_entrypoint() -> StringExtension:
return StringExtension()

View File

@@ -139,16 +139,21 @@ class Wan22FunControlToVideo(io.ComfyNode):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
spacial_scale = vae.spacial_compression_encode()
latent_channels = vae.latent_channels
latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
if latent_channels == 48:
concat_latent = comfy.latent_formats.Wan22().process_out(concat_latent)
else:
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
concat_latent[:,latent_channels:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask[:, :, :start_image.shape[0] + 3] = 0.0
ref_latent = None
@@ -159,11 +164,11 @@ class Wan22FunControlToVideo(io.ComfyNode):
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
concat_latent[:,:latent_channels,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": latent_channels})
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
@@ -201,7 +206,8 @@ class WanFirstLastFrameToVideo(io.ComfyNode):
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
spacial_scale = vae.spacial_compression_encode()
latent = torch.zeros([batch_size, vae.latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
@@ -786,6 +792,229 @@ class WanTrackToVideo(io.ComfyNode):
return io.NodeOutput(positive, negative, out_latent)
def linear_interpolation(features, input_fps, output_fps, output_len=None):
"""
features: shape=[1, T, 512]
input_fps: fps for audio, f_a
output_fps: fps for video, f_m
output_len: video length
"""
features = features.transpose(1, 2) # [1, 512, T]
seq_len = features.shape[2] / float(input_fps) # T/f_a
if output_len is None:
output_len = int(seq_len * output_fps) # f_m*T/f_a
output_features = torch.nn.functional.interpolate(
features, size=output_len, align_corners=True,
mode='linear') # [1, 512, output_len]
return output_features.transpose(1, 2) # [1, output_len, 512]
def get_sample_indices(original_fps,
total_frames,
target_fps,
num_sample,
fixed_start=None):
required_duration = num_sample / target_fps
required_origin_frames = int(np.ceil(required_duration * original_fps))
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
if not fixed_start is None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames
if max_start < 0:
raise ValueError("video length is too short")
start_frame = np.random.randint(0, max_start + 1)
start_time = start_frame / original_fps
end_time = start_time + required_duration
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
return frame_indices
def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_rate=30):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
scale = video_rate / fps
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
bucket_num = min_batch_num * batch_frames
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * video_rate) - audio_frame_num
batch_idx = get_sample_indices(
original_fps=video_rate,
total_frames=audio_frame_num + padd_audio_num,
target_fps=fps,
num_sample=bucket_num,
fixed_start=0)
batch_audio_eb = []
audio_sample_stride = int(video_rate / fps)
for bi in batch_idx:
if bi < audio_frame_num:
chosen_idx = list(
range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [
audio_frame_num - 1 if c >= audio_frame_num else c
for c in chosen_idx
]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
return batch_audio_eb, min_batch_num
def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
latent_t = ((length - 1) // 4) + 1
if audio_encoder_output is not None:
feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
video_rate = 30
fps = 16
feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
batch_frames = latent_t * 4
audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
if len(audio_embed_bucket.shape) == 3:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
elif len(audio_embed_bucket.shape) == 4:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
if audio_embed_bucket.shape[3] > 0:
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
frame_offset += batch_frames
if ref_image is not None:
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(ref_image[:, :, :, :3])
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
if ref_motion is not None:
if ref_motion.shape[0] > 73:
ref_motion = ref_motion[-73:]
ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if ref_motion.shape[0] < 73:
r = torch.ones([73, height, width, 3]) * 0.5
r[-ref_motion.shape[0]:] = ref_motion
ref_motion = r
ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])
if ref_motion_latent is not None:
ref_motion_latent = ref_motion_latent[:, :, -19:]
positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
control_video = vae.encode(control_video[:, :, :, :3])
control_video_out[:, :, :control_video.shape[2]] = control_video
# TODO: check if zero is better than none if none provided
positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
out_latent = {}
out_latent["samples"] = latent
return positive, negative, out_latent, frame_offset
class WanSoundImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSoundImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
io.Image.Input("ref_image", optional=True),
io.Image.Input("control_video", optional=True),
io.Image.Input("ref_motion", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
control_video=control_video, ref_motion=ref_motion)
return io.NodeOutput(positive, negative, out_latent)
class WanSoundImageToVideoExtend(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSoundImageToVideoExtend",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Latent.Input("video_latent"),
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
io.Image.Input("ref_image", optional=True),
io.Image.Input("control_video", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
video_latent = video_latent["samples"]
width = video_latent.shape[-1] * 8
height = video_latent.shape[-2] * 8
batch_size = video_latent.shape[0]
frame_offset = video_latent.shape[-3] * 4
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
control_video=control_video, ref_motion=None, ref_motion_latent=video_latent)
return io.NodeOutput(positive, negative, out_latent)
class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -844,6 +1073,8 @@ class WanExtension(ComfyExtension):
TrimVideoLatent,
WanCameraImageToVideo,
WanPhantomSubjectToVideo,
WanSoundImageToVideo,
WanSoundImageToVideoExtend,
Wan22ImageToVideoLatent,
]

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.51"
__version__ = "0.3.56"

View File

@@ -46,6 +46,10 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")

View File

@@ -112,6 +112,7 @@ import gc
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__":

View File

@@ -2322,6 +2322,11 @@ async def init_builtin_extra_nodes():
"nodes_tcfg.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_sortblock.py",
"nodes_easysortblock.py",
]
import_failed = []

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.51"
version = "0.3.56"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.25.9
comfyui-workflow-templates==0.1.62
comfyui-frontend-package==1.25.11
comfyui-workflow-templates==0.1.70
comfyui-embedded-docs==0.2.6
torch
torchsde