mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 05:30:03 +00:00
Compare commits
35 Commits
node-memor
...
v0.3.52
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71ed4a399e | ||
|
|
3e316c6338 | ||
|
|
8be0d22ab7 | ||
|
|
59eddda900 | ||
|
|
41048c69b4 | ||
|
|
fc247150fe | ||
|
|
fe31ad0276 | ||
|
|
ca4e96a8ae | ||
|
|
050c67323c | ||
|
|
497d41fb50 | ||
|
|
ff57793659 | ||
|
|
f7bd5e58dd | ||
|
|
7ed73d12d1 | ||
|
|
eb39019daa | ||
|
|
bab08f40d1 | ||
|
|
bc49106837 | ||
|
|
1b2de2642d | ||
|
|
9fa1036f60 | ||
|
|
0737b7e0d2 | ||
|
|
0963493a9c | ||
|
|
e73a9dbe30 | ||
|
|
fe01885acf | ||
|
|
7139d6d93f | ||
|
|
2f52e8f05f | ||
|
|
8d38ea3bbf | ||
|
|
5a8f502db5 | ||
|
|
7cd2c4bd6a | ||
|
|
dfa791eb4b | ||
|
|
bddd69618b | ||
|
|
54d8fdbed0 | ||
|
|
d844d8b13b | ||
|
|
07a927517c | ||
|
|
f16a70ba67 | ||
|
|
36b5127fd3 | ||
|
|
4977f203fa |
@@ -71,6 +71,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
@@ -191,7 +192,7 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
python 3.13 is supported but using 3.12 is recommended because some custom nodes and their dependencies might not support it yet.
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
|
||||
@@ -363,10 +363,17 @@ class UserManager():
|
||||
if not overwrite and os.path.exists(path):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
body = await request.read()
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
status=400,
|
||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||
)
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
if full_info:
|
||||
|
||||
@@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
|
||||
@@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet
|
||||
import comfy.cldm.mmdit
|
||||
import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -236,11 +237,11 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = None
|
||||
compression_ratio = self.compression_ratio
|
||||
if self.vae is not None:
|
||||
compression_ratio *= self.vae.downscale_ratio
|
||||
compression_ratio *= self.vae.spacial_compression_encode()
|
||||
else:
|
||||
if self.latent_format is not None:
|
||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||
if self.vae is not None:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
@@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
extra_conds = []
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
@@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
|
||||
controlnet_scale, lyrics_strength, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@@ -436,6 +437,13 @@ class MMDiT(nn.Module):
|
||||
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
||||
|
||||
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
# patchify x, add PE
|
||||
b, c, h, w = x.shape
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
@@ -253,6 +254,13 @@ class Chroma(nn.Module):
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask,
|
||||
fps,
|
||||
image_size,
|
||||
padding_mask,
|
||||
scalar_feature,
|
||||
data_type,
|
||||
latent_condition,
|
||||
latent_condition_sigma,
|
||||
condition_video_augment_sigma,
|
||||
**kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# crossattn_emb: torch.Tensor,
|
||||
# crossattn_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
image_size: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
scalar_feature: Optional[torch.Tensor] = None,
|
||||
data_type: Optional[DataType] = DataType.VIDEO,
|
||||
latent_condition: Optional[torch.Tensor] = None,
|
||||
latent_condition_sigma: Optional[torch.Tensor] = None,
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -11,6 +11,7 @@ import math
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
@@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module):
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -214,6 +215,13 @@ class Flux(nn.Module):
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h_orig, w_orig = x.shape
|
||||
patch_size = self.patch_size
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
raise NotImplementedError
|
||||
return x, x_masks, img_sizes
|
||||
|
||||
def forward(
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
|
||||
@@ -7,6 +7,7 @@ from comfy.ldm.flux.layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
@@ -67,6 +68,13 @@ class Hunyuan3Dv2(nn.Module):
|
||||
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
|
||||
x = x.movedim(-1, -2)
|
||||
timestep = 1.0 - timestep
|
||||
txt = context
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
@@ -348,6 +349,13 @@ class HunyuanVideo(nn.Module):
|
||||
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
img_ids = self.img_ids(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
@@ -420,6 +421,13 @@ class LTXVModel(torch.nn.Module):
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
@@ -590,8 +591,15 @@ class NextDiT(nn.Module):
|
||||
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
|
||||
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1+ scale.unsqueeze(1))
|
||||
|
||||
|
||||
#################################################################################
|
||||
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
|
||||
assert not self.pre_only
|
||||
attn1 = self.attn.post_attention(attn)
|
||||
attn2 = self.attn2.post_attention(attn2)
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
|
||||
out1 = gate_msa.unsqueeze(1) * attn1
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
|
||||
return x
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
if use_checkpoint:
|
||||
|
||||
77
comfy/ldm/qwen_image/controlnet.py
Normal file
77
comfy/ldm/qwen_image/controlnet.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
def __init__(
|
||||
self,
|
||||
extra_condition_channels=0,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
self.main_model_double = 60
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = torch.nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
|
||||
self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
hint, _, _ = self.process_img(hint)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
|
||||
|
||||
controlnet_block_samples = ()
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
|
||||
|
||||
return {"input": controlnet_block_samples[:self.main_model_double]}
|
||||
@@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@@ -214,9 +215,9 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -248,11 +249,11 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
@@ -275,7 +276,7 @@ class LastLayer(nn.Module):
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(conditioning_embedding))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
|
||||
return x
|
||||
|
||||
|
||||
@@ -293,6 +294,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@@ -300,6 +302,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
@@ -329,9 +332,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.gradient_checkpointing = False
|
||||
if final_layer:
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, t, h, w = x.shape
|
||||
@@ -347,13 +350,20 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
@@ -362,6 +372,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
@@ -396,10 +407,11 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
|
||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
@@ -415,6 +427,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
@@ -435,6 +448,19 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states += add
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
@@ -148,8 +149,8 @@ WAN_CROSSATTENTION_CLASSES = {
|
||||
|
||||
def repeat_e(e, x):
|
||||
repeats = 1
|
||||
if e.shape[1] > 1:
|
||||
repeats = x.shape[1] // e.shape[1]
|
||||
if e.size(1) > 1:
|
||||
repeats = x.size(1) // e.size(1)
|
||||
if repeats == 1:
|
||||
return e
|
||||
return torch.repeat_interleave(e, repeats, dim=1)
|
||||
@@ -219,15 +220,15 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs)
|
||||
|
||||
x = x + y * repeat_e(e[2], x)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
|
||||
x = x + y * repeat_e(e[5], x)
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
@@ -342,7 +343,7 @@ class Head(nn.Module):
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||
|
||||
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
|
||||
x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
|
||||
return x
|
||||
|
||||
|
||||
@@ -573,6 +574,13 @@ class WanModel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
|
||||
@@ -1325,6 +1325,7 @@ class Omnigen2(BaseModel):
|
||||
class QwenImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1342,3 +1343,10 @@ class QwenImage(BaseModel):
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
@@ -492,6 +492,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
|
||||
@@ -582,25 +582,24 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
def get_models_memory_reserve(models):
|
||||
total_reserve = 0
|
||||
for model in models:
|
||||
total_reserve += model.get_model_memory_reserve(convert_to_bytes=True)
|
||||
return total_reserve
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
models_memory_reserve = get_models_memory_reserve(models)
|
||||
extra_mem = max(inference_memory + models_memory_reserve, memory_required + extra_reserved_memory() + models_memory_reserve)
|
||||
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory + models_memory_reserve, minimum_memory_required + extra_reserved_memory() + models_memory_reserve)
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models = set(models)
|
||||
models_temp = set()
|
||||
for m in models:
|
||||
models_temp.add(m)
|
||||
for mm in m.model_patches_models():
|
||||
models_temp.add(mm)
|
||||
|
||||
models = models_temp
|
||||
|
||||
models_to_load = []
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import inspect
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -84,12 +84,6 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
def add_model_options_memory_reserve(model_options, memory_reserve_gb: float):
|
||||
if "model_memory_reserve" not in model_options:
|
||||
model_options["model_memory_reserve"] = []
|
||||
model_options["model_memory_reserve"].append(memory_reserve_gb)
|
||||
return model_options
|
||||
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
@@ -436,6 +430,9 @@ class ModelPatcher:
|
||||
def set_model_forward_timestep_embed_patch(self, patch):
|
||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||
|
||||
def set_model_double_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "double_block")
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
@@ -445,17 +442,6 @@ class ModelPatcher:
|
||||
self.force_cast_weights = True
|
||||
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
|
||||
|
||||
def add_model_memory_reserve(self, memory_reserve_gb: float):
|
||||
"""Adds additional expected memory usage for the model, in gigabytes."""
|
||||
self.model_options = add_model_options_memory_reserve(self.model_options, memory_reserve_gb)
|
||||
|
||||
def get_model_memory_reserve(self, convert_to_bytes: bool = False) -> Union[float, int]:
|
||||
"""Returns the total expected memory usage for the model in gigabytes, or bytes if convert_to_bytes is True."""
|
||||
total_reserve = sum(self.model_options.get("model_memory_reserve", []))
|
||||
if convert_to_bytes:
|
||||
return total_reserve * 1024 * 1024 * 1024
|
||||
return total_reserve
|
||||
|
||||
def add_weight_wrapper(self, name, function):
|
||||
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
@@ -503,6 +489,30 @@ class ModelPatcher:
|
||||
if hasattr(wrap_func, "to"):
|
||||
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
||||
|
||||
def model_patches_models(self):
|
||||
to = self.model_options["transformer_options"]
|
||||
models = []
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "models"):
|
||||
models += patch_list[i].models()
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "models"):
|
||||
models += patch_list[k].models()
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, "models"):
|
||||
models += wrap_func.models()
|
||||
|
||||
return models
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
@@ -50,6 +50,7 @@ class WrappersMP:
|
||||
OUTER_SAMPLE = "outer_sample"
|
||||
PREPARE_SAMPLING = "prepare_sampling"
|
||||
SAMPLER_SAMPLE = "sampler_sample"
|
||||
PREDICT_NOISE = "predict_noise"
|
||||
CALC_COND_BATCH = "calc_cond_batch"
|
||||
APPLY_MODEL = "apply_model"
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
@@ -17,6 +17,7 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -61,7 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert (mask.shape[1:] == x_in.shape[2:])
|
||||
# assert (mask.shape[1:] == x_in.shape[2:])
|
||||
|
||||
mask = mask[:input_x.shape[0]]
|
||||
if area is not None:
|
||||
@@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
||||
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
||||
|
||||
mask = mask * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||
if len(mask.shape) == len(dims):
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1:] != dims:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
|
||||
if mask.ndim < 4:
|
||||
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
|
||||
else:
|
||||
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
|
||||
|
||||
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
@@ -953,7 +957,14 @@ class CFGGuider:
|
||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.predict_noise(*args, **kwargs)
|
||||
return self.outer_predict_noise(*args, **kwargs)
|
||||
|
||||
def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self.predict_noise,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
|
||||
).execute(x, timestep, model_options, seed)
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
@@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
|
||||
index = 0
|
||||
pad_extra = 0
|
||||
embeds_info = []
|
||||
for o in other_embeds:
|
||||
emb = o[1]
|
||||
if torch.is_tensor(emb):
|
||||
emb = {"type": "embedding", "data": emb}
|
||||
|
||||
extra = None
|
||||
emb_type = emb.get("type", None)
|
||||
if emb_type == "embedding":
|
||||
emb = emb.get("data", None)
|
||||
else:
|
||||
if hasattr(self.transformer, "preprocess_embed"):
|
||||
emb = self.transformer.preprocess_embed(emb, device=device)
|
||||
emb, extra = self.transformer.preprocess_embed(emb, device=device)
|
||||
else:
|
||||
emb = None
|
||||
|
||||
@@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
|
||||
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
|
||||
index += emb_shape - 1
|
||||
embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
|
||||
else:
|
||||
index += -1
|
||||
pad_extra += emb_shape
|
||||
@@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
attention_masks.append(attention_mask)
|
||||
num_tokens.append(sum(attention_mask))
|
||||
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
|
||||
|
||||
def forward(self, tokens):
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
|
||||
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
|
||||
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
@@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
intermediate_output = self.layer_idx
|
||||
|
||||
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
||||
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs[0].float()
|
||||
@@ -531,7 +534,10 @@ class SDTokenizer:
|
||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
if kwargs.get("disable_weights", False):
|
||||
parsed_weights = [(text, 1.0)]
|
||||
else:
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
# tokenize words
|
||||
tokens = []
|
||||
|
||||
@@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module):
|
||||
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
|
||||
@@ -2,12 +2,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
import math
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
import comfy.model_management
|
||||
from . import qwen_vl
|
||||
|
||||
@dataclass
|
||||
class Llama2Config:
|
||||
@@ -25,6 +27,7 @@ class Llama2Config:
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@@ -42,6 +45,7 @@ class Qwen25_3BConfig:
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
rope_dims = None
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@@ -59,6 +63,7 @@ class Qwen25_7BVLI_Config:
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = True
|
||||
rope_dims = [16, 24, 24]
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@@ -76,6 +81,7 @@ class Gemma2_2B_Config:
|
||||
rms_norm_add = True
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@@ -100,24 +106,30 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
|
||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
|
||||
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
|
||||
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
|
||||
|
||||
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
|
||||
|
||||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
if rope_dims is not None and position_ids.shape[0] > 1:
|
||||
mrope_section = rope_dims * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
|
||||
else:
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
cos = freqs_cis[0].unsqueeze(1)
|
||||
sin = freqs_cis[1].unsqueeze(1)
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||
return q_embed, k_embed
|
||||
@@ -277,7 +289,7 @@ class Llama2_(nn.Module):
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@@ -286,9 +298,13 @@ class Llama2_(nn.Module):
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
x.shape[1],
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_dims,
|
||||
device=x.device)
|
||||
|
||||
mask = None
|
||||
@@ -372,8 +388,38 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
|
||||
return self.visual(image.to(device, dtype=torch.float32), grid), grid
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
grid = None
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
start = e.get("index")
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
|
||||
position_ids[0, start:end] = start
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
|
||||
|
||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 151655:
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
|
||||
428
comfy/text_encoders/qwen_vl.py
Normal file
428
comfy/text_encoders/qwen_vl.py
Normal file
@@ -0,0 +1,428 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
def process_qwen2vl_images(
|
||||
images: torch.Tensor,
|
||||
min_pixels: int = 3136,
|
||||
max_pixels: int = 12845056,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
merge_size: int = 2,
|
||||
image_mean: list = None,
|
||||
image_std: list = None,
|
||||
):
|
||||
if image_mean is None:
|
||||
image_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
if image_std is None:
|
||||
image_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
batch_size, height, width, channels = images.shape
|
||||
device = images.device
|
||||
# dtype = images.dtype
|
||||
|
||||
images = images.permute(0, 3, 1, 2)
|
||||
|
||||
grid_thw_list = []
|
||||
img = images[0]
|
||||
|
||||
factor = patch_size * merge_size
|
||||
|
||||
h_bar = round(height / factor) * factor
|
||||
w_bar = round(width / factor) * factor
|
||||
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(factor, math.floor(height / beta / factor) * factor)
|
||||
w_bar = max(factor, math.floor(width / beta / factor) * factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
|
||||
img_resized = F.interpolate(
|
||||
img.unsqueeze(0),
|
||||
size=(h_bar, w_bar),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
normalized = img_resized.clone()
|
||||
for c in range(3):
|
||||
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
|
||||
|
||||
grid_h = h_bar // patch_size
|
||||
grid_w = w_bar // patch_size
|
||||
grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long)
|
||||
|
||||
pixel_values = normalized
|
||||
grid_thw_list.append(grid_thw)
|
||||
image_grid_thw = torch.stack(grid_thw_list)
|
||||
|
||||
grid_t = 1
|
||||
channel = pixel_values.shape[0]
|
||||
pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1)
|
||||
|
||||
patches = pixel_values.reshape(
|
||||
grid_t,
|
||||
temporal_patch_size,
|
||||
channel,
|
||||
grid_h // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
grid_w // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
)
|
||||
|
||||
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
||||
flatten_patches = patches.reshape(
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * temporal_patch_size * patch_size * patch_size
|
||||
)
|
||||
|
||||
return flatten_patches, image_grid_thw
|
||||
|
||||
|
||||
class VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
in_channels: int = 3,
|
||||
embed_dim: int = 3584,
|
||||
device=None,
|
||||
dtype=None,
|
||||
ops=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
||||
self.proj = ops.Conv3d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = hidden_states.view(
|
||||
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
||||
)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
return hidden_states.view(-1, self.embed_dim)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(q, k, cos, sin):
|
||||
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class VisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, theta: float = 10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, seqlen: int, device) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
|
||||
seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
|
||||
freqs = torch.outer(seq, inv_freq)
|
||||
return freqs
|
||||
|
||||
|
||||
class PatchMerger(nn.Module):
|
||||
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size ** 2)
|
||||
self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = nn.Sequential(
|
||||
ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.ln_q(x).reshape(-1, self.hidden_size)
|
||||
x = self.mlp(x)
|
||||
return x
|
||||
|
||||
|
||||
class VisionAttention(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
cu_seqlens=None,
|
||||
optimized_attention=None,
|
||||
) -> torch.Tensor:
|
||||
if hidden_states.dim() == 2:
|
||||
seq_length, _ = hidden_states.shape
|
||||
batch_size = 1
|
||||
hidden_states = hidden_states.unsqueeze(0)
|
||||
else:
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
qkv = self.qkv(hidden_states)
|
||||
qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
|
||||
query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
splits = [
|
||||
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
||||
]
|
||||
|
||||
attn_outputs = [
|
||||
optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
|
||||
for q, k, v in zip(*splits)
|
||||
]
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class VisionMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||
|
||||
|
||||
class VisionBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
cu_seqlens=None,
|
||||
optimized_attention=None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLVisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 3584,
|
||||
output_hidden_size: int = 3584,
|
||||
intermediate_size: int = 3420,
|
||||
num_heads: int = 16,
|
||||
num_layers: int = 32,
|
||||
patch_size: int = 14,
|
||||
temporal_patch_size: int = 2,
|
||||
spatial_merge_size: int = 2,
|
||||
window_size: int = 112,
|
||||
device=None,
|
||||
dtype=None,
|
||||
ops=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = [7, 15, 23, 31]
|
||||
|
||||
self.patch_embed = VisionPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
in_channels=3,
|
||||
embed_dim=hidden_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
ops=ops,
|
||||
)
|
||||
|
||||
head_dim = hidden_size // num_heads
|
||||
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.merger = PatchMerger(
|
||||
dim=output_hidden_size,
|
||||
context_dim=hidden_size,
|
||||
spatial_merge_size=spatial_merge_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
ops=ops,
|
||||
)
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index = []
|
||||
cu_window_seqlens = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h = grid_h // self.spatial_merge_size
|
||||
llm_grid_w = grid_w // self.spatial_merge_size
|
||||
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
|
||||
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
|
||||
index_padded = index_padded.reshape(
|
||||
grid_t,
|
||||
num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t,
|
||||
num_windows_h * num_windows_w,
|
||||
vit_merger_window_size,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
|
||||
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1]
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def get_position_embeddings(self, grid_thw, device):
|
||||
pos_ids = []
|
||||
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
|
||||
|
||||
wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
|
||||
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
|
||||
return rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
image_grid_thw: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)
|
||||
|
||||
hidden_states = self.patch_embed(pixel_values)
|
||||
|
||||
window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
|
||||
cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
|
||||
position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)
|
||||
|
||||
seq_len, _ = hidden_states.size()
|
||||
spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
|
||||
hidden_states = hidden_states[window_index, :, :]
|
||||
hidden_states = hidden_states.reshape(seq_len, -1)
|
||||
|
||||
position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
|
||||
position_embeddings = position_embeddings[window_index, :, :]
|
||||
position_embeddings = position_embeddings.reshape(seq_len, -1)
|
||||
position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
|
||||
position_embeddings = (position_embeddings.cos(), position_embeddings.sin())
|
||||
|
||||
cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
if i in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
|
||||
|
||||
hidden_states = self.merger(hidden_states)
|
||||
return hidden_states
|
||||
@@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module):
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
|
||||
@@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO):
|
||||
class AnyType(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="MODEL_PATCH")
|
||||
class MODEL_PATCH(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
|
||||
16
comfy_api_nodes/apis/__init__.py
generated
16
comfy_api_nodes/apis/__init__.py
generated
@@ -1315,6 +1315,7 @@ class KlingTaskStatus(str, Enum):
|
||||
class KlingTextToVideoModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
|
||||
|
||||
class KlingVideoGenAspectRatio(str, Enum):
|
||||
@@ -1347,6 +1348,8 @@ class KlingVideoGenModelName(str, Enum):
|
||||
kling_v1_5 = 'kling-v1-5'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_master = 'kling-v2-master'
|
||||
kling_v2_1 = 'kling-v2-1'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
|
||||
|
||||
class KlingVideoResult(BaseModel):
|
||||
@@ -1620,13 +1623,14 @@ class MinimaxTaskResultResponse(BaseModel):
|
||||
task_id: str = Field(..., description='The task ID being queried.')
|
||||
|
||||
|
||||
class Model(str, Enum):
|
||||
class MiniMaxModel(str, Enum):
|
||||
T2V_01_Director = 'T2V-01-Director'
|
||||
I2V_01_Director = 'I2V-01-Director'
|
||||
S2V_01 = 'S2V-01'
|
||||
I2V_01 = 'I2V-01'
|
||||
I2V_01_live = 'I2V-01-live'
|
||||
T2V_01 = 'T2V-01'
|
||||
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||
|
||||
|
||||
class SubjectReferenceItem(BaseModel):
|
||||
@@ -1648,7 +1652,7 @@ class MinimaxVideoGenerationRequest(BaseModel):
|
||||
None,
|
||||
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||
)
|
||||
model: Model = Field(
|
||||
model: MiniMaxModel = Field(
|
||||
...,
|
||||
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||
)
|
||||
@@ -1665,6 +1669,14 @@ class MinimaxVideoGenerationRequest(BaseModel):
|
||||
None,
|
||||
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||
)
|
||||
duration: Optional[int] = Field(
|
||||
None,
|
||||
description="The length of the output video in seconds."
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationResponse(BaseModel):
|
||||
|
||||
@@ -5,7 +5,10 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
|
||||
@@ -46,6 +49,8 @@ class GeminiModel(str, Enum):
|
||||
|
||||
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
|
||||
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
||||
gemini_2_5_pro = "gemini-2.5-pro"
|
||||
gemini_2_5_flash = "gemini-2.5-flash"
|
||||
|
||||
|
||||
def get_gemini_endpoint(
|
||||
@@ -97,7 +102,7 @@ class GeminiNode(ComfyNodeABC):
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiModel],
|
||||
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
|
||||
"default": GeminiModel.gemini_2_5_pro.value,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
@@ -348,7 +353,27 @@ class GeminiNode(ComfyNodeABC):
|
||||
# Get result output
|
||||
output_text = self.get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
return (output_text or "Empty response from Gemini model...",)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import io
|
||||
import torch
|
||||
from comfy_api_nodes.apis import (
|
||||
IdeogramGenerateRequest,
|
||||
@@ -246,90 +246,81 @@ def display_image_urls_on_node(image_urls, node_id):
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V1 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV1(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
@@ -337,13 +328,15 @@ class IdeogramV1(ComfyNodeABC):
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
@@ -364,7 +357,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
@@ -377,93 +370,85 @@ class IdeogramV1(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V2 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV2(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"turbo",
|
||||
default=False,
|
||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V1_V2_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
optional=True,
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V1_RES_MAP.keys()),
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=list(V1_V1_RES_MAP.keys()),
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"style_type": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
"default": "NONE",
|
||||
"tooltip": "Style type for generation (V2 only)",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"style_type",
|
||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
default="NONE",
|
||||
tooltip="Style type for generation (V2 only)",
|
||||
optional=True,
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of what to exclude from the image",
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
#"color_palette": (
|
||||
# IO.STRING,
|
||||
@@ -473,22 +458,20 @@ class IdeogramV2(ComfyNodeABC):
|
||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||
# },
|
||||
#),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
@@ -499,8 +482,6 @@ class IdeogramV2(ComfyNodeABC):
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
@@ -517,6 +498,10 @@ class IdeogramV2(ComfyNodeABC):
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
@@ -540,7 +525,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
@@ -553,108 +538,99 @@ class IdeogramV2(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
class IdeogramV3(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation or editing",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="api node/image/Ideogram",
|
||||
description="Generates images using the Ideogram V3 model. "
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation or editing",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image for image editing.",
|
||||
optional=True,
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
comfy_io.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||
optional=True,
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V3_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=list(V3_RATIO_MAP.keys()),
|
||||
default="1:1",
|
||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
optional=True,
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": V3_RESOLUTIONS,
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=V3_RESOLUTIONS,
|
||||
default="Auto",
|
||||
tooltip="The resolution for image generation. "
|
||||
"If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
optional=True,
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"magic_prompt_option",
|
||||
options=["AUTO", "ON", "OFF"],
|
||||
default="AUTO",
|
||||
tooltip="Determine if MagicPrompt should be used in generation",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
comfy_io.Int.Input(
|
||||
"num_images",
|
||||
default=1,
|
||||
min=1,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
optional=True,
|
||||
),
|
||||
"rendering_speed": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["BALANCED", "TURBO", "QUALITY"],
|
||||
"default": "BALANCED",
|
||||
"tooltip": "Controls the trade-off between generation speed and quality",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"rendering_speed",
|
||||
options=["BALANCED", "TURBO", "QUALITY"],
|
||||
default="BALANCED",
|
||||
tooltip="Controls the trade-off between generation speed and quality",
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
image=None,
|
||||
mask=None,
|
||||
@@ -664,9 +640,11 @@ class IdeogramV3(ComfyNodeABC):
|
||||
seed=0,
|
||||
num_images=1,
|
||||
rendering_speed="BALANCED",
|
||||
unique_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
@@ -686,7 +664,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
# Process image
|
||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
@@ -695,7 +673,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
# Process mask - white areas will be replaced
|
||||
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_byte_arr = io.BytesIO()
|
||||
mask_byte_arr = BytesIO()
|
||||
mask_img.save(mask_byte_arr, format="PNG")
|
||||
mask_byte_arr.seek(0)
|
||||
mask_binary = mask_byte_arr
|
||||
@@ -729,7 +707,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
"mask": mask_binary,
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
@@ -770,7 +748,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
@@ -784,18 +762,18 @@ class IdeogramV3(ComfyNodeABC):
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, unique_id)
|
||||
return (await download_and_process_images(image_urls),)
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"IdeogramV1": IdeogramV1,
|
||||
"IdeogramV2": IdeogramV2,
|
||||
"IdeogramV3": IdeogramV3,
|
||||
}
|
||||
class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"IdeogramV1": "Ideogram V1",
|
||||
"IdeogramV2": "Ideogram V2",
|
||||
"IdeogramV3": "Ideogram V3",
|
||||
}
|
||||
async def comfy_entrypoint() -> IdeogramExtension:
|
||||
return IdeogramExtension()
|
||||
|
||||
@@ -421,6 +421,8 @@ class KlingTextToVideoNode(KlingNodeBase):
|
||||
"pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
|
||||
"standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
|
||||
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
|
||||
"pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"),
|
||||
"pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
@@ -10,7 +11,7 @@ from comfy_api_nodes.apis import (
|
||||
MinimaxFileRetrieveResponse,
|
||||
MinimaxTaskResultResponse,
|
||||
SubjectReferenceItem,
|
||||
Model
|
||||
MiniMaxModel
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -84,7 +85,6 @@ class MinimaxTextToVideoNode:
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
@@ -121,7 +121,7 @@ class MinimaxTextToVideoNode:
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
model=Model(model),
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
first_frame_image=image_url,
|
||||
@@ -251,7 +251,6 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||
@@ -313,7 +312,181 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
class MinimaxHailuoVideoNode:
|
||||
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
"first_frame_image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Optional image to use as the first frame to generate a video."
|
||||
},
|
||||
),
|
||||
"prompt_optimizer": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"tooltip": "Optimize prompt to improve generation quality when needed.",
|
||||
"default": True,
|
||||
},
|
||||
),
|
||||
"duration": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The length of the output video in seconds.",
|
||||
"default": 6,
|
||||
"options": [6, 10],
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The dimensions of the video display. "
|
||||
"1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
|
||||
"default": "768P",
|
||||
"options": ["768P", "1080P"],
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
prompt_text,
|
||||
seed=0,
|
||||
first_frame_image: torch.Tensor=None, # used for ImageToVideo
|
||||
prompt_optimizer=True,
|
||||
duration=6,
|
||||
resolution="768P",
|
||||
model="MiniMax-Hailuo-02",
|
||||
unique_id: Union[str, None]=None,
|
||||
**kwargs,
|
||||
):
|
||||
if first_frame_image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
|
||||
if model == "MiniMax-Hailuo-02" and resolution.upper() == "1080P" and duration != 6:
|
||||
raise Exception(
|
||||
"When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds."
|
||||
)
|
||||
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if first_frame_image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
first_frame_image=image_url,
|
||||
prompt_optimizer=prompt_optimizer,
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
average_duration = 120 if resolution == "768P" else 240
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=unique_id,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info(f"Generated video URL: {file_url}")
|
||||
if unique_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, unique_id)
|
||||
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return (VideoFromFile(video_io),)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
@@ -322,6 +495,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
||||
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
||||
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
||||
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
@@ -329,4 +503,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
||||
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
||||
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
||||
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
|
||||
}
|
||||
|
||||
@@ -80,6 +80,9 @@ class SupportedOpenAIModel(str, Enum):
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
gpt_5 = "gpt-5"
|
||||
gpt_5_mini = "gpt-5-mini"
|
||||
gpt_5_nano = "gpt-5-nano"
|
||||
|
||||
|
||||
class OpenAIDalle2(ComfyNodeABC):
|
||||
@@ -995,7 +998,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||
"OpenAIChatNode": "OpenAI Chat",
|
||||
"OpenAIInputFiles": "OpenAI Chat Input Files",
|
||||
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
|
||||
"OpenAIChatNode": "OpenAI ChatGPT",
|
||||
"OpenAIInputFiles": "OpenAI ChatGPT Input Files",
|
||||
"OpenAIChatConfig": "OpenAI ChatGPT Advanced Options",
|
||||
}
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import io
|
||||
import logging
|
||||
import base64
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse
|
||||
VeoGenVidPollResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
@@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import (
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
tensor_to_base64_string
|
||||
tensor_to_base64_string,
|
||||
)
|
||||
|
||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||
@@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona
|
||||
return None
|
||||
|
||||
|
||||
class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo API.
|
||||
|
||||
@@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text description of the video",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="VeoVideoGenerationNode",
|
||||
display_name="Google Veo 2 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["16:9", "9:16"],
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of the output video",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Negative text prompt to guide what to avoid in the video",
|
||||
},
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
"duration_seconds": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 5,
|
||||
"min": 5,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"duration_seconds",
|
||||
default=5,
|
||||
min=5,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
"enhance_prompt": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Whether to enhance the prompt with AI assistance",
|
||||
}
|
||||
comfy_io.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
"person_generation": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["ALLOW", "BLOCK"],
|
||||
"default": "ALLOW",
|
||||
"tooltip": "Whether to allow generating people in the video",
|
||||
},
|
||||
comfy_io.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFF,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation (0 for random)",
|
||||
},
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
"image": (IO.IMAGE, {
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image to guide video generation",
|
||||
}),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-2.0-generate-001"],
|
||||
"default": "veo-2.0-generate-001",
|
||||
"tooltip": "Veo 2 model to use for video generation",
|
||||
},
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["veo-2.0-generate-001"],
|
||||
default="veo-2.0-generate-001",
|
||||
tooltip="Veo 2 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/Veo"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
|
||||
API_NODE = True
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
aspect_ratio="16:9",
|
||||
negative_prompt="",
|
||||
@@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
image=None,
|
||||
model="veo-2.0-generate-001",
|
||||
generate_audio=False,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
@@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
if "veo-3.0" in model:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -214,7 +209,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
@@ -248,10 +243,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
auth_kwargs=auth,
|
||||
poll_interval=5.0,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=unique_id,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
@@ -304,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
logging.info("Video generation completed successfully")
|
||||
|
||||
# Convert video data to BytesIO object
|
||||
video_io = io.BytesIO(video_data)
|
||||
video_io = BytesIO(video_data)
|
||||
|
||||
# Return VideoFromFile object
|
||||
return (VideoFromFile(video_io),)
|
||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
@@ -323,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
parent_input = super().INPUT_TYPES()
|
||||
|
||||
# Update model options for Veo 3
|
||||
parent_input["optional"]["model"] = (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
"default": "veo-3.0-generate-001",
|
||||
"tooltip": "Veo 3 model to use for video generation",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="Veo3VideoGenerationNode",
|
||||
display_name="Google Veo 3 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||
inputs=[
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration_seconds",
|
||||
default=8,
|
||||
min=8,
|
||||
max=8,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"person_generation",
|
||||
options=["ALLOW", "BLOCK"],
|
||||
default="ALLOW",
|
||||
tooltip="Whether to allow generating people in the video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="Optional reference image to guide video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
# Add generateAudio parameter
|
||||
parent_input["optional"]["generate_audio"] = (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
|
||||
}
|
||||
)
|
||||
|
||||
# Update duration constraints for Veo 3 (only 8 seconds supported)
|
||||
parent_input["optional"]["duration_seconds"] = (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 8,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||
},
|
||||
)
|
||||
class VeoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
VeoVideoGenerationNode,
|
||||
Veo3VideoGenerationNode,
|
||||
]
|
||||
|
||||
return parent_input
|
||||
|
||||
|
||||
# Register the nodes
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
||||
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
|
||||
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
|
||||
}
|
||||
async def comfy_entrypoint() -> VeoExtension:
|
||||
return VeoExtension()
|
||||
|
||||
622
comfy_api_nodes/nodes_vidu.py
Normal file
622
comfy_api_nodes/nodes_vidu.py
Normal file
@@ -0,0 +1,622 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Literal, TypeVar
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_aspect_ratio_closeness,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio_range,
|
||||
get_number_of_images,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
|
||||
|
||||
|
||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
|
||||
VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
|
||||
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
|
||||
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
class VideoModelName(str, Enum):
|
||||
vidu_q1 = 'viduq1'
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
r_16_9 = "16:9"
|
||||
r_9_16 = "9:16"
|
||||
r_1_1 = "1:1"
|
||||
|
||||
|
||||
class Resolution(str, Enum):
|
||||
r_1080p = "1080p"
|
||||
|
||||
|
||||
class MovementAmplitude(str, Enum):
|
||||
auto = "auto"
|
||||
small = "small"
|
||||
medium = "medium"
|
||||
large = "large"
|
||||
|
||||
|
||||
class TaskCreationRequest(BaseModel):
|
||||
model: VideoModelName = VideoModelName.vidu_q1
|
||||
prompt: Optional[str] = Field(None, max_length=1500)
|
||||
duration: Optional[Literal[5]] = 5
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
|
||||
resolution: Optional[Resolution] = Resolution.r_1080p
|
||||
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
|
||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
created = "created"
|
||||
queueing = "queueing"
|
||||
processing = "processing"
|
||||
success = "success"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: TaskStatus = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: Optional[int] = Field(None, description="Error code")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
id: str = Field(..., description="Creation id")
|
||||
url: str = Field(..., description="The URL of the generated results, valid for one hour")
|
||||
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: TaskStatus = Field(...)
|
||||
err_code: Optional[str] = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[TaskStatus.success.value],
|
||||
failed_statuses=[TaskStatus.failed.value],
|
||||
status_extractor=lambda response: response.state.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
if response.creations:
|
||||
return response.creations[0].url
|
||||
return None
|
||||
|
||||
|
||||
def get_video_from_response(response) -> TaskResult:
|
||||
if not response.creations:
|
||||
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
|
||||
logging.info(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
|
||||
return response.creations[0]
|
||||
|
||||
|
||||
async def execute_task(
|
||||
vidu_endpoint: str,
|
||||
auth_kwargs: Optional[dict[str, str]],
|
||||
payload: TaskCreationRequest,
|
||||
estimated_duration: int,
|
||||
node_id: str,
|
||||
) -> R:
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=vidu_endpoint,
|
||||
method=HttpMethod.POST,
|
||||
request_model=TaskCreationRequest,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
if response.state == TaskStatus.failed:
|
||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=VIDU_GET_GENERATION_STATUS % response.task_id,
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
||||
class ViduTextToVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from text prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[model.value for model in AspectRatio],
|
||||
default=AspectRatio.r_16_9.value,
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduImageToVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduImageToVideoNode",
|
||||
display_name="Vidu Image To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="A textual description for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
if get_number_of_images(image) > 1:
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from multiple images and prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"images",
|
||||
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[model.value for model in AspectRatio],
|
||||
default=AspectRatio.r_16_9.value,
|
||||
tooltip="The aspect ratio of the output video",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
images: torch.Tensor,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
if not prompt:
|
||||
raise ValueError("The prompt field is required and cannot be empty.")
|
||||
a = get_number_of_images(images)
|
||||
if a > 7:
|
||||
raise ValueError("Too many images, maximum allowed is 7.")
|
||||
for image in images:
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
images,
|
||||
max_images=7,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return comfy_io.Schema(
|
||||
node_id="ViduStartEndToVideoNode",
|
||||
display_name="Vidu Start End To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
comfy_io.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in VideoModelName],
|
||||
default=VideoModelName.vidu_q1.value,
|
||||
tooltip="Model name",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"first_frame",
|
||||
tooltip="Start frame",
|
||||
),
|
||||
comfy_io.Image.Input(
|
||||
"end_frame",
|
||||
tooltip="End frame",
|
||||
),
|
||||
comfy_io.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=5,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=comfy_io.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation (0 for random)",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"resolution",
|
||||
options=[model.value for model in Resolution],
|
||||
default=Resolution.r_1080p.value,
|
||||
tooltip="Supported values may vary by model & duration",
|
||||
optional=True,
|
||||
),
|
||||
comfy_io.Combo.Input(
|
||||
"movement_amplitude",
|
||||
options=[model.value for model in MovementAmplitude],
|
||||
default=MovementAmplitude.auto.value,
|
||||
tooltip="The movement amplitude of objects in the frame",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
comfy_io.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
comfy_io.Hidden.auth_token_comfy_org,
|
||||
comfy_io.Hidden.api_key_comfy_org,
|
||||
comfy_io.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
first_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> comfy_io.NodeOutput:
|
||||
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = [
|
||||
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
]
|
||||
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
|
||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||
return [
|
||||
ViduTextToVideoNode,
|
||||
ViduImageToVideoNode,
|
||||
ViduReferenceVideoNode,
|
||||
ViduStartEndToVideoNode,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> ViduExtension:
|
||||
return ViduExtension()
|
||||
@@ -53,6 +53,53 @@ def validate_image_aspect_ratio(
|
||||
)
|
||||
|
||||
|
||||
def validate_image_aspect_ratio_range(
|
||||
image: torch.Tensor,
|
||||
min_ratio: tuple[float, float], # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float], # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
a1, b1 = min_ratio
|
||||
a2, b2 = max_ratio
|
||||
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
|
||||
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
|
||||
lo, hi = (a1 / b1), (a2 / b2)
|
||||
if lo > hi:
|
||||
lo, hi = hi, lo
|
||||
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
|
||||
w, h = get_image_dimensions(image)
|
||||
if w <= 0 or h <= 0:
|
||||
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
||||
ar = w / h
|
||||
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
|
||||
if not ok:
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
|
||||
return ar
|
||||
|
||||
|
||||
def validate_aspect_ratio_closeness(
|
||||
start_img,
|
||||
end_img,
|
||||
min_rel: float,
|
||||
max_rel: float,
|
||||
*,
|
||||
strict: bool = False, # True => exclusive, False => inclusive
|
||||
) -> None:
|
||||
w1, h1 = get_image_dimensions(start_img)
|
||||
w2, h2 = get_image_dimensions(end_img)
|
||||
if min(w1, h1, w2, h2) <= 0:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
ar1 = w1 / h1
|
||||
ar2 = w2 / h2
|
||||
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
|
||||
closeness = max(ar1, ar2) / min(ar1, ar2)
|
||||
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
|
||||
if (closeness >= limit) if strict else (closeness > limit):
|
||||
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
|
||||
|
||||
|
||||
def validate_video_dimensions(
|
||||
video: VideoInput,
|
||||
min_width: Optional[int] = None,
|
||||
@@ -98,3 +145,9 @@ def validate_video_duration(
|
||||
raise ValueError(
|
||||
f"Video duration must be at most {max_duration}s, got {duration}s"
|
||||
)
|
||||
|
||||
|
||||
def get_number_of_images(images):
|
||||
if isinstance(images, torch.Tensor):
|
||||
return images.shape[0] if images.ndim >= 4 else 1
|
||||
return len(images)
|
||||
|
||||
@@ -1,49 +1,63 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class TextEncodeAceStepAudio:
|
||||
|
||||
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def encode(self, clip, tags, lyrics, lyrics_strength):
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return (conditioning, )
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
io.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds, batch_size):
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
length = int(seconds * 44100 / 512 / 8)
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
|
||||
return ({"samples": latent, "type": "audio"}, )
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
|
||||
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
|
||||
}
|
||||
class AceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeAceStepAudio,
|
||||
EmptyAceStepLatentAudio,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
return AceExtension()
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import trange
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm.auto import trange
|
||||
from comfy.k_diffusion.sampling import to_d
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
|
||||
return x
|
||||
|
||||
|
||||
class SamplerLCMUpscale:
|
||||
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||
class SamplerLCMUpscale(io.ComfyNode):
|
||||
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
|
||||
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerLCMUpscale",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
|
||||
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
|
||||
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
|
||||
@classmethod
|
||||
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
|
||||
if scale_steps < 0:
|
||||
scale_steps = None
|
||||
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
||||
return (sampler, )
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
from comfy.k_diffusion.sampling import to_d
|
||||
import comfy.model_patcher
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
@@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
return x
|
||||
|
||||
|
||||
class SamplerEulerCFGpp:
|
||||
class SamplerEulerCFGpp(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"version": (["regular", "alternative"],),}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
# CATEGORY = "sampling/custom_sampling/samplers"
|
||||
CATEGORY = "_for_testing"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerEulerCFGpp",
|
||||
display_name="SamplerEulerCFG++",
|
||||
category="_for_testing", # "sampling/custom_sampling/samplers"
|
||||
inputs=[
|
||||
io.Combo.Input("version", options=["regular", "alternative"]),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, version):
|
||||
@classmethod
|
||||
def execute(cls, version) -> io.NodeOutput:
|
||||
if version == "alternative":
|
||||
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
||||
else:
|
||||
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
||||
return (sampler, )
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SamplerLCMUpscale": SamplerLCMUpscale,
|
||||
"SamplerEulerCFGpp": SamplerEulerCFGpp,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SamplerEulerCFGpp": "SamplerEulerCFG++",
|
||||
}
|
||||
class AdvancedSamplersExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SamplerLCMUpscale,
|
||||
SamplerEulerCFGpp,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AdvancedSamplersExtension:
|
||||
return AdvancedSamplersExtension()
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def project(v0, v1):
|
||||
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
||||
@@ -6,22 +10,45 @@ def project(v0, v1):
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
class APG:
|
||||
class APG(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
|
||||
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
|
||||
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="APG",
|
||||
display_name="Adaptive Projected Guidance",
|
||||
category="sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
"eta",
|
||||
default=1.0,
|
||||
min=-10.0,
|
||||
max=10.0,
|
||||
step=0.01,
|
||||
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"norm_threshold",
|
||||
default=5.0,
|
||||
min=0.0,
|
||||
max=50.0,
|
||||
step=0.1,
|
||||
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
|
||||
),
|
||||
io.Float.Input(
|
||||
"momentum",
|
||||
default=0.0,
|
||||
min=-5.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
|
||||
def patch(self, model, eta, norm_threshold, momentum):
|
||||
@classmethod
|
||||
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
|
||||
running_avg = 0
|
||||
prev_sigma = None
|
||||
|
||||
@@ -65,12 +92,15 @@ class APG:
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||
return (m,)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"APG": APG,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"APG": "Adaptive Projected Guidance",
|
||||
}
|
||||
class ApgExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
APG,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> ApgExtension:
|
||||
return ApgExtension()
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def attention_multiply(attn, model, q, k, v, out):
|
||||
m = model.clone()
|
||||
@@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out):
|
||||
return m
|
||||
|
||||
|
||||
class UNetSelfAttentionMultiply:
|
||||
class UNetSelfAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetSelfAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
m = attention_multiply("attn1", model, q, k, v, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class UNetCrossAttentionMultiply:
|
||||
|
||||
class UNetCrossAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetCrossAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||
m = attention_multiply("attn2", model, q, k, v, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class CLIPAttentionMultiply:
|
||||
|
||||
class CLIPAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip": ("CLIP",),
|
||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, clip, q, k, v, out):
|
||||
@classmethod
|
||||
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
|
||||
m = clip.clone()
|
||||
sd = m.patcher.model_state_dict()
|
||||
|
||||
@@ -79,23 +97,28 @@ class CLIPAttentionMultiply:
|
||||
m.add_patches({key: (None,)}, 0.0, v)
|
||||
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
||||
m.add_patches({key: (None,)}, 0.0, out)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class UNetTemporalAttentionMultiply:
|
||||
|
||||
class UNetTemporalAttentionMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetTemporalAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
CATEGORY = "_for_testing/attention_experiments"
|
||||
|
||||
def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
|
||||
@classmethod
|
||||
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
sd = model.model_state_dict()
|
||||
|
||||
@@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
||||
else:
|
||||
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
||||
return (m, )
|
||||
return io.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
|
||||
"UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
|
||||
"CLIPAttentionMultiply": CLIPAttentionMultiply,
|
||||
"UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
|
||||
}
|
||||
|
||||
class AttentionMultiplyExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
UNetSelfAttentionMultiply,
|
||||
UNetCrossAttentionMultiply,
|
||||
CLIPAttentionMultiply,
|
||||
UNetTemporalAttentionMultiply,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AttentionMultiplyExtension:
|
||||
return AttentionMultiplyExtension()
|
||||
|
||||
459
comfy_extras/nodes_easycache.py
Normal file
459
comfy_extras/nodes_easycache.py
Normal file
@@ -0,0 +1,459 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
import comfy.patcher_extension
|
||||
import logging
|
||||
import torch
|
||||
import comfy.model_patcher
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
transformer_options: dict[str] = args[-1]
|
||||
if not isinstance(transformer_options, dict):
|
||||
transformer_options = kwargs.get("transformer_options")
|
||||
if not transformer_options:
|
||||
transformer_options = args[-2]
|
||||
easycache: EasyCacheHolder = transformer_options["easycache"]
|
||||
sigmas = transformer_options["sigmas"]
|
||||
uuids = transformer_options["uuids"]
|
||||
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
|
||||
return executor(*args, **kwargs)
|
||||
# prepare next x_prev
|
||||
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
do_easycache = easycache.should_do_easycache(sigmas)
|
||||
if do_easycache:
|
||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||
if easycache.skip_current_step:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
if easycache.initial_step:
|
||||
easycache.first_cond_uuid = uuids[0]
|
||||
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
||||
easycache.initial_step = False
|
||||
if has_first_cond_uuid:
|
||||
if easycache.has_x_prev_subsampled():
|
||||
input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.cumulative_change_rate += approx_output_change_rate
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
easycache.skip_current_step = True
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
else:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
easycache.cumulative_change_rate = 0.0
|
||||
|
||||
output: torch.Tensor = executor(*args, **kwargs)
|
||||
if has_first_cond_uuid and easycache.has_output_prev_norm():
|
||||
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
output_change_rate = output_change / easycache.output_prev_norm
|
||||
easycache.output_change_rates.append(output_change_rate.item())
|
||||
if easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
||||
if input_change is not None:
|
||||
easycache.relative_transformation_rate = output_change / input_change
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||
# TODO: allow cache_diff to be offloaded
|
||||
easycache.update_cache_diff(output, next_x_prev, uuids)
|
||||
if has_first_cond_uuid:
|
||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
|
||||
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||
return output
|
||||
|
||||
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
timestep: float = args[1]
|
||||
model_options: dict[str] = args[2]
|
||||
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
if easycache.is_past_end_timestep(timestep):
|
||||
return executor(*args, **kwargs)
|
||||
# prepare next x_prev
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
do_easycache = easycache.should_do_easycache(timestep)
|
||||
if do_easycache:
|
||||
if easycache.has_x_prev_subsampled():
|
||||
if easycache.has_x_prev_subsampled():
|
||||
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.cumulative_change_rate += approx_output_change_rate
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
easycache.skip_current_step = True
|
||||
return easycache.apply_cache_diff(x)
|
||||
else:
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
easycache.cumulative_change_rate = 0.0
|
||||
output: torch.Tensor = executor(*args, **kwargs)
|
||||
if easycache.has_output_prev_norm():
|
||||
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
output_change_rate = output_change / easycache.output_prev_norm
|
||||
easycache.output_change_rates.append(output_change_rate.item())
|
||||
if easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
||||
if input_change is not None:
|
||||
easycache.relative_transformation_rate = output_change / input_change
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||
# TODO: allow cache_diff to be offloaded
|
||||
easycache.update_cache_diff(output, next_x_prev)
|
||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
|
||||
easycache.output_prev_subsampled = easycache.subsample(output)
|
||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||
return output
|
||||
|
||||
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
|
||||
model_options = args[-1]
|
||||
easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
easycache.skip_current_step = False
|
||||
# TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
def easycache_sample_wrapper(executor, *args, **kwargs):
|
||||
"""
|
||||
This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
|
||||
"""
|
||||
try:
|
||||
guider = executor.class_obj
|
||||
orig_model_options = guider.model_options
|
||||
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
|
||||
# clone and prepare timesteps
|
||||
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
|
||||
easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
|
||||
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
|
||||
return executor(*args, **kwargs)
|
||||
finally:
|
||||
easycache = guider.model_options['transformer_options']['easycache']
|
||||
output_change_rates = easycache.output_change_rates
|
||||
approx_output_change_rates = easycache.approx_output_change_rates
|
||||
if easycache.verbose:
|
||||
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
|
||||
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
|
||||
total_steps = len(args[3])-1
|
||||
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
|
||||
easycache.reset()
|
||||
guider.model_options = orig_model_options
|
||||
|
||||
|
||||
class EasyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
self.name = "EasyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
self.end_percent = end_percent
|
||||
self.subsample_factor = subsample_factor
|
||||
self.offload_cache_diff = offload_cache_diff
|
||||
self.verbose = verbose
|
||||
# timestep values
|
||||
self.start_t = 0.0
|
||||
self.end_t = 0.0
|
||||
# control values
|
||||
self.relative_transformation_rate: float = None
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
self.skip_current_step = False
|
||||
# cache values
|
||||
self.first_cond_uuid = None
|
||||
self.x_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_norm: torch.Tensor = None
|
||||
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
# how to deal with mismatched dims
|
||||
self.allow_mismatch = True
|
||||
self.cut_from_start = True
|
||||
|
||||
def is_past_end_timestep(self, timestep: float) -> bool:
|
||||
return not (timestep[0] > self.end_t).item()
|
||||
|
||||
def should_do_easycache(self, timestep: float) -> bool:
|
||||
return (timestep[0] <= self.start_t).item()
|
||||
|
||||
def has_x_prev_subsampled(self) -> bool:
|
||||
return self.x_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_subsampled(self) -> bool:
|
||||
return self.output_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_norm(self) -> bool:
|
||||
return self.output_prev_norm is not None
|
||||
|
||||
def has_relative_transformation_rate(self) -> bool:
|
||||
return self.relative_transformation_rate is not None
|
||||
|
||||
def prepare_timesteps(self, model_sampling):
|
||||
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
|
||||
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
|
||||
return self
|
||||
|
||||
def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
uuid_idx = uuids.index(self.first_cond_uuid)
|
||||
if self.subsample_factor > 1:
|
||||
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
|
||||
if clone:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
|
||||
if clone:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||
if self.first_cond_uuid in uuids:
|
||||
self.total_steps_skipped += 1
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
|
||||
slicing = []
|
||||
skip_this_dim = True
|
||||
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
|
||||
if skip_this_dim:
|
||||
skip_this_dim = False
|
||||
continue
|
||||
if dim_u != dim_x:
|
||||
if self.cut_from_start:
|
||||
slicing.append(slice(dim_x-dim_u, None))
|
||||
else:
|
||||
slicing.append(slice(None, dim_u))
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
|
||||
x = x[slicing]
|
||||
x += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if output.shape[1:] != x.shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
|
||||
slicing = []
|
||||
skip_dim = True
|
||||
for dim_o, dim_x in zip(output.shape, x.shape):
|
||||
if not skip_dim and dim_o != dim_x:
|
||||
if self.cut_from_start:
|
||||
slicing.append(slice(dim_x-dim_o, None))
|
||||
else:
|
||||
slicing.append(slice(None, dim_o))
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
skip_dim = False
|
||||
x = x[slicing]
|
||||
diff = output - x
|
||||
batch_offset = diff.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
|
||||
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
|
||||
return self.first_cond_uuid in uuids
|
||||
|
||||
def reset(self):
|
||||
self.relative_transformation_rate = 0.0
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
self.skip_current_step = False
|
||||
self.output_change_rates = []
|
||||
self.first_cond_uuid = None
|
||||
del self.x_prev_subsampled
|
||||
self.x_prev_subsampled = None
|
||||
del self.output_prev_subsampled
|
||||
self.output_prev_subsampled = None
|
||||
del self.output_prev_norm
|
||||
self.output_prev_norm = None
|
||||
del self.uuid_cache_diffs
|
||||
self.uuid_cache_diffs = {}
|
||||
self.total_steps_skipped = 0
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
|
||||
|
||||
class EasyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="EasyCache",
|
||||
display_name="EasyCache",
|
||||
description="Native EasyCache implementation.",
|
||||
category="advanced/debug/model",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to add EasyCache to."),
|
||||
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
|
||||
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
|
||||
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
|
||||
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with EasyCache."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class LazyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
self.name = "LazyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
self.end_percent = end_percent
|
||||
self.subsample_factor = subsample_factor
|
||||
self.offload_cache_diff = offload_cache_diff
|
||||
self.verbose = verbose
|
||||
# timestep values
|
||||
self.start_t = 0.0
|
||||
self.end_t = 0.0
|
||||
# control values
|
||||
self.relative_transformation_rate: float = None
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
# cache values
|
||||
self.x_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_norm: torch.Tensor = None
|
||||
self.cache_diff: torch.Tensor = None
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
|
||||
def has_cache_diff(self) -> bool:
|
||||
return self.cache_diff is not None
|
||||
|
||||
def is_past_end_timestep(self, timestep: float) -> bool:
|
||||
return not (timestep[0] > self.end_t).item()
|
||||
|
||||
def should_do_easycache(self, timestep: float) -> bool:
|
||||
return (timestep[0] <= self.start_t).item()
|
||||
|
||||
def has_x_prev_subsampled(self) -> bool:
|
||||
return self.x_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_subsampled(self) -> bool:
|
||||
return self.output_prev_subsampled is not None
|
||||
|
||||
def has_output_prev_norm(self) -> bool:
|
||||
return self.output_prev_norm is not None
|
||||
|
||||
def has_relative_transformation_rate(self) -> bool:
|
||||
return self.relative_transformation_rate is not None
|
||||
|
||||
def prepare_timesteps(self, model_sampling):
|
||||
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
|
||||
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
|
||||
return self
|
||||
|
||||
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
|
||||
if self.subsample_factor > 1:
|
||||
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
|
||||
if clone:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
if clone:
|
||||
return x.clone()
|
||||
return x
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor):
|
||||
self.total_steps_skipped += 1
|
||||
return x + self.cache_diff.to(x.device)
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
|
||||
self.cache_diff = output - x
|
||||
|
||||
def reset(self):
|
||||
self.relative_transformation_rate = 0.0
|
||||
self.cumulative_change_rate = 0.0
|
||||
self.initial_step = True
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
del self.cache_diff
|
||||
self.cache_diff = None
|
||||
self.total_steps_skipped = 0
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
|
||||
class LazyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LazyCache",
|
||||
display_name="LazyCache",
|
||||
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
|
||||
category="advanced/debug/model",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to add LazyCache to."),
|
||||
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
|
||||
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
|
||||
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
|
||||
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with LazyCache."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class EasyCacheExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EasyCacheNode,
|
||||
LazyCacheNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
return EasyCacheExtension()
|
||||
@@ -166,7 +166,7 @@ class LTXVAddGuide:
|
||||
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
|
||||
class MemoryReserveNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ReserveAdditionalMemory",
|
||||
display_name="Reserve Additional Memory",
|
||||
description="Adds additional expected memory usage for the model, in gigabytes.",
|
||||
category="advanced/debug/model",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to add memory reserve to."),
|
||||
io.Float.Input("memory_reserve_gb", min=0.0, default=0.0, max=2048.0, step=0.1, tooltip="The additional expected memory usage for the model, in gigabytes."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with the additional memory reserve."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, memory_reserve_gb: float) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.add_model_memory_reserve(memory_reserve_gb)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class MemoryReserveExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
MemoryReserveNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
return MemoryReserveExtension()
|
||||
161
comfy_extras/nodes_model_patch.py
Normal file
161
comfy_extras/nodes_model_patch.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
# [linear, gelu, linear]
|
||||
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
|
||||
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
|
||||
self.input_proj = operations.Linear(dim, dim)
|
||||
self.act = torch.nn.GELU()
|
||||
self.output_proj = operations.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, y):
|
||||
x, y = self.x_rms(x), self.y_rms(y)
|
||||
x = self.input_proj(x + y)
|
||||
x = self.act(x)
|
||||
x = self.output_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
in_dim: int = 64,
|
||||
additional_in_dim: int = 0,
|
||||
dim: int = 3072,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
|
||||
self.controlnet_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def process_input_latent_image(self, latent_image):
|
||||
latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16])
|
||||
patch_size = 2
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
return self.img_in(hidden_states)
|
||||
|
||||
def control_block(self, img, controlnet_conditioning, block_id):
|
||||
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||
|
||||
|
||||
class ModelPatchLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL_PATCH",)
|
||||
FUNCTION = "load_model_patch"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_model_patch(self, name):
|
||||
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
|
||||
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
|
||||
dtype = comfy.utils.weight_dtype(sd)
|
||||
# TODO: this node will work with more types of model patches
|
||||
additional_in_dim = sd["img_in.weight"].shape[1] - 64
|
||||
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
model.load_state_dict(sd)
|
||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
return (model,)
|
||||
|
||||
|
||||
class DiffSynthCnetPatch:
|
||||
def __init__(self, model_patch, vae, image, strength, mask=None):
|
||||
self.model_patch = model_patch
|
||||
self.vae = vae
|
||||
self.image = image
|
||||
self.strength = strength
|
||||
self.mask = mask
|
||||
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
|
||||
|
||||
def encode_latent_cond(self, image):
|
||||
latent_image = self.vae.encode(image)
|
||||
if self.model_patch.model.additional_in_dim > 0:
|
||||
if self.mask is None:
|
||||
mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4]
|
||||
else:
|
||||
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
|
||||
|
||||
return torch.cat([latent_image, mask_], dim=1)
|
||||
else:
|
||||
return latent_image
|
||||
|
||||
def __call__(self, kwargs):
|
||||
x = kwargs.get("x")
|
||||
img = kwargs.get("img")
|
||||
block_index = kwargs.get("block_index")
|
||||
if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
|
||||
spacial_compression = self.vae.spacial_compression_encode()
|
||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
|
||||
img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
|
||||
kwargs['img'] = img
|
||||
return kwargs
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
||||
return self
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class QwenImageDiffsynthControlnet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"model_patch": ("MODEL_PATCH",),
|
||||
"vae": ("VAE",),
|
||||
"image": ("IMAGE",),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {"mask": ("MASK",)}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "diffsynth_controlnet"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "advanced/loaders/qwen"
|
||||
|
||||
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
|
||||
model_patched = model.clone()
|
||||
image = image[:, :, :, :3]
|
||||
if mask is not None:
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
if mask.ndim == 4:
|
||||
mask = mask.unsqueeze(2)
|
||||
mask = 1.0 - mask
|
||||
|
||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||
return (model_patched,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
}
|
||||
48
comfy_extras/nodes_qwen.py
Normal file
48
comfy_extras/nodes_qwen.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import node_helpers
|
||||
import comfy.utils
|
||||
import math
|
||||
|
||||
|
||||
class TextEncodeQwenImageEdit:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
},
|
||||
"optional": {"vae": ("VAE", ),
|
||||
"image": ("IMAGE", ),}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, prompt, vae=None, image=None):
|
||||
ref_latent = None
|
||||
if image is None:
|
||||
images = []
|
||||
else:
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(1024 * 1024)
|
||||
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
|
||||
image = s.movedim(1, -1)
|
||||
images = [image[:, :, :, :3]]
|
||||
if vae is not None:
|
||||
ref_latent = vae.encode(image[:, :, :, :3])
|
||||
|
||||
tokens = clip.tokenize(prompt, images=images)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
if ref_latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
|
||||
}
|
||||
@@ -1,77 +1,91 @@
|
||||
import re
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class StringConcatenate():
|
||||
|
||||
class StringConcatenate(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringConcatenate",
|
||||
display_name="Concatenate",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
io.String.Input("delimiter", multiline=False, default=""),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, delimiter, **kwargs):
|
||||
return delimiter.join((string_a, string_b)),
|
||||
|
||||
class StringSubstring():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"start": (IO.INT, {}),
|
||||
"end": (IO.INT, {}),
|
||||
}
|
||||
}
|
||||
def execute(cls, string_a, string_b, delimiter):
|
||||
return io.NodeOutput(delimiter.join((string_a, string_b)))
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, start, end, **kwargs):
|
||||
return string[start:end],
|
||||
|
||||
class StringLength():
|
||||
class StringSubstring(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringSubstring",
|
||||
display_name="Substring",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Int.Input("start"),
|
||||
io.Int.Input("end"),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.INT,)
|
||||
RETURN_NAMES = ("length",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, **kwargs):
|
||||
length = len(string)
|
||||
|
||||
return length,
|
||||
|
||||
class CaseConverter():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
|
||||
}
|
||||
}
|
||||
def execute(cls, string, start, end):
|
||||
return io.NodeOutput(string[start:end])
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
class StringLength(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringLength",
|
||||
display_name="Length",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="length"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string):
|
||||
return io.NodeOutput(len(string))
|
||||
|
||||
|
||||
class CaseConverter(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CaseConverter",
|
||||
display_name="Case Converter",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string, mode):
|
||||
if mode == "UPPERCASE":
|
||||
result = string.upper()
|
||||
elif mode == "lowercase":
|
||||
@@ -83,24 +97,27 @@ class CaseConverter():
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class StringTrim():
|
||||
class StringTrim(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringTrim",
|
||||
display_name="Trim",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["Both", "Left", "Right"]),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, mode, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, mode):
|
||||
if mode == "Both":
|
||||
result = string.strip()
|
||||
elif mode == "Left":
|
||||
@@ -110,70 +127,78 @@ class StringTrim():
|
||||
else:
|
||||
result = string
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
class StringReplace():
|
||||
|
||||
class StringReplace(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"find": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringReplace",
|
||||
display_name="Replace",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("find", multiline=True),
|
||||
io.String.Input("replace", multiline=True),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, find, replace, **kwargs):
|
||||
result = string.replace(find, replace)
|
||||
return result,
|
||||
|
||||
|
||||
class StringContains():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"substring": (IO.STRING, {"multiline": True}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
def execute(cls, string, find, replace):
|
||||
return io.NodeOutput(string.replace(find, replace))
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("contains",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, substring, case_sensitive, **kwargs):
|
||||
class StringContains(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringContains",
|
||||
display_name="Contains",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("substring", multiline=True),
|
||||
io.Boolean.Input("case_sensitive", default=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(display_name="contains"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string, substring, case_sensitive):
|
||||
if case_sensitive:
|
||||
contains = substring in string
|
||||
else:
|
||||
contains = substring.lower() in string.lower()
|
||||
|
||||
return contains,
|
||||
return io.NodeOutput(contains)
|
||||
|
||||
|
||||
class StringCompare():
|
||||
class StringCompare(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string_a": (IO.STRING, {"multiline": True}),
|
||||
"string_b": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
|
||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringCompare",
|
||||
display_name="Compare",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]),
|
||||
io.Boolean.Input("case_sensitive", default=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string_a, string_b, mode, case_sensitive):
|
||||
if case_sensitive:
|
||||
a = string_a
|
||||
b = string_b
|
||||
@@ -182,31 +207,34 @@ class StringCompare():
|
||||
b = string_b.lower()
|
||||
|
||||
if mode == "Equal":
|
||||
return a == b,
|
||||
return io.NodeOutput(a == b)
|
||||
elif mode == "Starts With":
|
||||
return a.startswith(b),
|
||||
return io.NodeOutput(a.startswith(b))
|
||||
elif mode == "Ends With":
|
||||
return a.endswith(b),
|
||||
return io.NodeOutput(a.endswith(b))
|
||||
|
||||
class RegexMatch():
|
||||
|
||||
class RegexMatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexMatch",
|
||||
display_name="Regex Match",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
io.Boolean.Input("case_insensitive", default=True),
|
||||
io.Boolean.Input("multiline", default=False),
|
||||
io.Boolean.Input("dotall", default=False),
|
||||
],
|
||||
outputs=[
|
||||
io.Boolean.Output(display_name="matches"),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.BOOLEAN,)
|
||||
RETURN_NAMES = ("matches",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
@@ -223,29 +251,32 @@ class RegexMatch():
|
||||
except re.error:
|
||||
result = False
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class RegexExtract():
|
||||
class RegexExtract(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False}),
|
||||
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexExtract",
|
||||
display_name="Regex Extract",
|
||||
category="utils/string",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]),
|
||||
io.Boolean.Input("case_insensitive", default=True),
|
||||
io.Boolean.Input("multiline", default=False),
|
||||
io.Boolean.Input("dotall", default=False),
|
||||
io.Int.Input("group_index", default=1, min=0, max=100),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index):
|
||||
join_delimiter = "\n"
|
||||
|
||||
flags = 0
|
||||
@@ -294,32 +325,33 @@ class RegexExtract():
|
||||
except re.error:
|
||||
result = ""
|
||||
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
|
||||
class RegexReplace():
|
||||
DESCRIPTION = "Find and replace text using regex patterns."
|
||||
class RegexReplace(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"string": (IO.STRING, {"multiline": True}),
|
||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
||||
"replace": (IO.STRING, {"multiline": True}),
|
||||
},
|
||||
"optional": {
|
||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
||||
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
|
||||
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexReplace",
|
||||
display_name="Regex Replace",
|
||||
category="utils/string",
|
||||
description="Find and replace text using regex patterns.",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
io.String.Input("replace", multiline=True),
|
||||
io.Boolean.Input("case_insensitive", default=True, optional=True),
|
||||
io.Boolean.Input("multiline", default=False, optional=True),
|
||||
io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."),
|
||||
io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.STRING,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "utils/string"
|
||||
|
||||
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
|
||||
@classmethod
|
||||
def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0):
|
||||
flags = 0
|
||||
|
||||
if case_insensitive:
|
||||
@@ -329,32 +361,25 @@ class RegexReplace():
|
||||
if dotall:
|
||||
flags |= re.DOTALL
|
||||
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
||||
return result,
|
||||
return io.NodeOutput(result)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StringConcatenate": StringConcatenate,
|
||||
"StringSubstring": StringSubstring,
|
||||
"StringLength": StringLength,
|
||||
"CaseConverter": CaseConverter,
|
||||
"StringTrim": StringTrim,
|
||||
"StringReplace": StringReplace,
|
||||
"StringContains": StringContains,
|
||||
"StringCompare": StringCompare,
|
||||
"RegexMatch": RegexMatch,
|
||||
"RegexExtract": RegexExtract,
|
||||
"RegexReplace": RegexReplace,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StringConcatenate": "Concatenate",
|
||||
"StringSubstring": "Substring",
|
||||
"StringLength": "Length",
|
||||
"CaseConverter": "Case Converter",
|
||||
"StringTrim": "Trim",
|
||||
"StringReplace": "Replace",
|
||||
"StringContains": "Contains",
|
||||
"StringCompare": "Compare",
|
||||
"RegexMatch": "Regex Match",
|
||||
"RegexExtract": "Regex Extract",
|
||||
"RegexReplace": "Regex Replace",
|
||||
}
|
||||
class StringExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
StringConcatenate,
|
||||
StringSubstring,
|
||||
StringLength,
|
||||
CaseConverter,
|
||||
StringTrim,
|
||||
StringReplace,
|
||||
StringContains,
|
||||
StringCompare,
|
||||
RegexMatch,
|
||||
RegexExtract,
|
||||
RegexReplace,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StringExtension:
|
||||
return StringExtension()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.50"
|
||||
__version__ = "0.3.52"
|
||||
|
||||
@@ -46,6 +46,8 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
|
||||
|
||||
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
||||
|
||||
folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
|
||||
0
models/model_patches/put_model_patches_here
Normal file
0
models/model_patches/put_model_patches_here
Normal file
5
nodes.py
5
nodes.py
@@ -2321,7 +2321,9 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_memory_reserve.py",
|
||||
"nodes_qwen.py",
|
||||
"nodes_model_patch.py",
|
||||
"nodes_easycache.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@@ -2351,6 +2353,7 @@ async def init_builtin_api_nodes():
|
||||
"nodes_moonvalley.py",
|
||||
"nodes_rodin.py",
|
||||
"nodes_gemini.py",
|
||||
"nodes_vidu.py",
|
||||
]
|
||||
|
||||
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.50"
|
||||
version = "0.3.52"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.25.9
|
||||
comfyui-workflow-templates==0.1.60
|
||||
comfyui-frontend-package==1.25.10
|
||||
comfyui-workflow-templates==0.1.65
|
||||
comfyui-embedded-docs==0.2.6
|
||||
torch
|
||||
torchsde
|
||||
|
||||
Reference in New Issue
Block a user