mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-18 14:10:07 +00:00
Compare commits
3 Commits
feature/fr
...
search-ali
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1202709996 | ||
|
|
dcde86463c | ||
|
|
f02abedcd9 |
@@ -1,23 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._node_replace import NodeReplace
|
||||
|
||||
REGISTERED_NODE_REPLACEMENTS: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register_node_replacement(node_replace: NodeReplace):
|
||||
REGISTERED_NODE_REPLACEMENTS.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def registered_as_dict():
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list] for k, v_list in REGISTERED_NODE_REPLACEMENTS.items()
|
||||
}
|
||||
|
||||
class NodeReplaceManager:
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(registered_as_dict())
|
||||
@@ -1,202 +0,0 @@
|
||||
from comfy.ldm.cosmos.predict2 import MiniTrainDIT
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, head_dim):
|
||||
super().__init__()
|
||||
self.rope_theta = 10000
|
||||
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
inner_dim = head_dim * n_heads
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.query_dim = query_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
|
||||
context = x if context is None else context
|
||||
input_shape = x.shape[:-1]
|
||||
q_shape = (*input_shape, self.n_heads, self.head_dim)
|
||||
context_shape = context.shape[:-1]
|
||||
kv_shape = (*context_shape, self.n_heads, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
assert position_embeddings_context is not None
|
||||
cos, sin = position_embeddings
|
||||
query_states = apply_rotary_pos_emb(query_states, cos, sin)
|
||||
cos, sin = position_embeddings_context
|
||||
key_states = apply_rotary_pos_emb(key_states, cos, sin)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
def init_weights(self):
|
||||
torch.nn.init.zeros_(self.o_proj.weight)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.use_self_attn = use_self_attn
|
||||
|
||||
if self.use_self_attn:
|
||||
self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.self_attn = Attention(
|
||||
query_dim=model_dim,
|
||||
context_dim=model_dim,
|
||||
n_heads=num_heads,
|
||||
head_dim=model_dim//num_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.cross_attn = Attention(
|
||||
query_dim=model_dim,
|
||||
context_dim=source_dim,
|
||||
n_heads=num_heads,
|
||||
head_dim=model_dim//num_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
|
||||
if self.use_self_attn:
|
||||
normed = self.norm_self_attn(x)
|
||||
attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
|
||||
x = x + attn_out
|
||||
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
|
||||
x = x + attn_out
|
||||
|
||||
x = x + self.mlp(self.norm_mlp(x))
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
torch.nn.init.zeros_(self.mlp[2].weight)
|
||||
self.cross_attn.init_weights()
|
||||
|
||||
|
||||
class LLMAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
source_dim=1024,
|
||||
target_dim=1024,
|
||||
model_dim=1024,
|
||||
num_layers=6,
|
||||
num_heads=16,
|
||||
use_self_attn=True,
|
||||
layer_norm=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
|
||||
if model_dim != target_dim:
|
||||
self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.in_proj = nn.Identity()
|
||||
self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
|
||||
])
|
||||
self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
|
||||
if target_attention_mask is not None:
|
||||
target_attention_mask = target_attention_mask.to(torch.bool)
|
||||
if target_attention_mask.ndim == 2:
|
||||
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
if source_attention_mask is not None:
|
||||
source_attention_mask = source_attention_mask.to(torch.bool)
|
||||
if source_attention_mask.ndim == 2:
|
||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
x = self.in_proj(self.embed(target_input_ids))
|
||||
context = source_hidden_states
|
||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||
position_embeddings = self.rotary_emb(x, position_ids)
|
||||
position_embeddings_context = self.rotary_emb(x, position_ids_context)
|
||||
for block in self.blocks:
|
||||
x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
|
||||
return self.norm(self.out_proj(x))
|
||||
|
||||
|
||||
class Anima(MiniTrainDIT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
else:
|
||||
return text_embeds
|
||||
@@ -62,8 +62,6 @@ class WanSelfAttention(nn.Module):
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
def qkv_fn_q(x):
|
||||
@@ -88,10 +86,6 @@ class WanSelfAttention(nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "attn1_patch" in patches:
|
||||
for p in patches["attn1_patch"]:
|
||||
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
@@ -231,8 +225,6 @@ class WanAttentionBlock(nn.Module):
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
@@ -250,11 +242,6 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
if "attn2_patch" in patches:
|
||||
for p in patches["attn2_patch"]:
|
||||
x = p({"x": x, "transformer_options": transformer_options})
|
||||
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
@@ -501,7 +488,7 @@ class WanModel(torch.nn.Module):
|
||||
self.blocks = nn.ModuleList([
|
||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for i in range(num_layers)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
# head
|
||||
@@ -554,7 +541,6 @@ class WanModel(torch.nn.Module):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
@@ -752,7 +738,6 @@ class VaceWanModel(WanModel):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
|
||||
@@ -1,500 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
import comfy
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
|
||||
scale = 1.0 / visual_q.shape[-1] ** 0.5
|
||||
visual_q = visual_q.transpose(1, 2) * scale
|
||||
|
||||
B, H, x_seqlens, K = visual_q.shape
|
||||
|
||||
x_ref_attn_maps = []
|
||||
for class_idx, ref_target_mask in enumerate(ref_target_masks):
|
||||
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
|
||||
|
||||
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
|
||||
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
|
||||
|
||||
for i in range(0, x_seqlens, chunk_size):
|
||||
end_i = min(i + chunk_size, x_seqlens)
|
||||
|
||||
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
|
||||
|
||||
# Apply softmax
|
||||
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
|
||||
attn_chunk = (attn_chunk - attn_max).exp()
|
||||
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
|
||||
attn_chunk = attn_chunk / (attn_sum + 1e-8)
|
||||
|
||||
# Apply mask and sum
|
||||
masked_attn = attn_chunk * ref_target_mask
|
||||
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
|
||||
|
||||
del attn_chunk, masked_attn
|
||||
|
||||
# Average across heads
|
||||
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
|
||||
x_ref_attn_maps.append(x_ref_attnmap)
|
||||
|
||||
del visual_q, ref_k
|
||||
|
||||
return torch.cat(x_ref_attn_maps, dim=0)
|
||||
|
||||
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
|
||||
"""Args:
|
||||
query (torch.tensor): B M H K
|
||||
key (torch.tensor): B M H K
|
||||
shape (tuple): (N_t, N_h, N_w)
|
||||
ref_target_masks: [B, N_h * N_w]
|
||||
"""
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_seqlens = N_h * N_w
|
||||
ref_k = ref_k[:, :x_seqlens]
|
||||
_, seq_lens, heads, _ = visual_q.shape
|
||||
class_num, _ = ref_target_masks.shape
|
||||
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
|
||||
|
||||
split_chunk = heads // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
|
||||
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_target_masks
|
||||
)
|
||||
x_ref_attn_maps += x_ref_attn_maps_perhead
|
||||
|
||||
return x_ref_attn_maps / split_num
|
||||
|
||||
|
||||
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
|
||||
source_min, source_max = source_range
|
||||
new_min, new_max = target_range
|
||||
normalized = (column - source_min) / (source_max - source_min + epsilon)
|
||||
scaled = normalized * (new_max - new_min) + new_min
|
||||
return scaled
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
def get_audio_embeds(encoded_audio, audio_start, audio_end):
|
||||
audio_embs = []
|
||||
human_num = len(encoded_audio)
|
||||
audio_frames = encoded_audio[0].shape[0]
|
||||
|
||||
indices = (torch.arange(4 + 1) - 2) * 1
|
||||
|
||||
for human_idx in range(human_num):
|
||||
if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
|
||||
pad_len = audio_end - audio_frames
|
||||
pad_shape = list(encoded_audio[human_idx].shape)
|
||||
pad_shape[0] = pad_len
|
||||
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
|
||||
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
|
||||
else:
|
||||
encoded_audio_in = encoded_audio[human_idx]
|
||||
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
|
||||
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
|
||||
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
|
||||
audio_embs.append(audio_emb)
|
||||
|
||||
return torch.cat(audio_embs, dim=0)
|
||||
|
||||
|
||||
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
|
||||
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
|
||||
|
||||
first_frame_audio_emb_s = audio_embs[:, :1, ...]
|
||||
latter_frame_audio_emb = audio_embs[:, 1:, ...]
|
||||
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
|
||||
|
||||
middle_index = audio_proj.seq_len // 2
|
||||
|
||||
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||
|
||||
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
||||
audio_emb = torch.cat(audio_emb.split(1), dim=2)
|
||||
|
||||
return audio_emb
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.base = 10000
|
||||
|
||||
def precompute_freqs_cis_1d(self, pos_indices):
|
||||
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
|
||||
freqs = freqs.to(pos_indices.device)
|
||||
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
|
||||
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, pos_indices):
|
||||
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
|
||||
|
||||
x_ = x.float()
|
||||
|
||||
freqs_cis = freqs_cis.float().to(x.device)
|
||||
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||
x_ = (x_ * cos) + (rotate_half(x_) * sin)
|
||||
|
||||
return x_.type_as(x)
|
||||
|
||||
class SingleStreamAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.encoder_hidden_states_dim = encoder_hidden_states_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
expected_tokens = N_t * N_h * N_w
|
||||
actual_tokens = x.shape[1]
|
||||
x_extra = None
|
||||
|
||||
if actual_tokens != expected_tokens:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
|
||||
B = x.shape[0]
|
||||
S = N_h * N_w
|
||||
x = x.view(B * N_t, S, self.dim)
|
||||
|
||||
# get q for hidden_state
|
||||
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
|
||||
|
||||
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
|
||||
kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
|
||||
|
||||
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# linear transform
|
||||
x = self.proj(x.reshape(B * N_t, S, self.dim))
|
||||
x = x.view(B, N_t * S, self.dim)
|
||||
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
class SingleStreamMultiAttention(SingleStreamAttention):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
class_range: int = 24,
|
||||
class_interval: int = 4,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
encoder_hidden_states_dim=encoder_hidden_states_dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
# Rotary-embedding layout parameters
|
||||
self.class_interval = class_interval
|
||||
self.class_range = class_range
|
||||
self.max_humans = self.class_range // self.class_interval
|
||||
|
||||
# Constant bucket used for background tokens
|
||||
self.rope_bak = int(self.class_range // 2)
|
||||
|
||||
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
shape=None,
|
||||
x_ref_attn_map=None
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
|
||||
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
|
||||
# Single-speaker fall-through
|
||||
if human_num <= 1:
|
||||
return super().forward(x, encoder_hidden_states, shape)
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_extra = None
|
||||
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
|
||||
|
||||
# Query projection
|
||||
B, N, C = x.shape
|
||||
q = self.q_linear(x)
|
||||
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
||||
|
||||
# Use `class_range` logic for 2 speakers
|
||||
rope_h1 = (0, self.class_interval)
|
||||
rope_h2 = (self.class_range - self.class_interval, self.class_range)
|
||||
rope_bak = int(self.class_range // 2)
|
||||
|
||||
# Normalize and scale attention maps for each speaker
|
||||
max_values = x_ref_attn_map.max(1).values[:, None, None]
|
||||
min_values = x_ref_attn_map.min(1).values[:, None, None]
|
||||
max_min_values = torch.cat([max_values, min_values], dim=2)
|
||||
|
||||
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
|
||||
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
|
||||
|
||||
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
|
||||
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
|
||||
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
|
||||
|
||||
# Token-wise speaker dominance
|
||||
max_indices = x_ref_attn_map.argmax(dim=0)
|
||||
normalized_map = torch.stack([human1, human2, back], dim=1)
|
||||
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
|
||||
|
||||
# Apply rotary to Q
|
||||
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
q = self.rope_1d(q, normalized_pos)
|
||||
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Keys / Values
|
||||
_, N_a, _ = encoder_hidden_states.shape
|
||||
encoder_kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
encoder_k, encoder_v = encoder_kv.unbind(0)
|
||||
|
||||
# Rotary for keys – assign centre of each speaker bucket to its context tokens
|
||||
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
|
||||
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
|
||||
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
|
||||
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
|
||||
|
||||
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
encoder_k = self.rope_1d(encoder_k, encoder_pos)
|
||||
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Final attention
|
||||
q = rearrange(q, "B H M K -> B M H K")
|
||||
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
||||
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
||||
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# Linear projection
|
||||
x = x.reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
# Restore original layout
|
||||
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkAudioProjModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len: int = 5,
|
||||
seq_len_vf: int = 12,
|
||||
blocks: int = 12,
|
||||
channels: int = 768,
|
||||
intermediate_dim: int = 512,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels
|
||||
self.input_dim_vf = seq_len_vf * blocks * channels
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.out_dim = out_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, audio_embeds, audio_embeds_vf):
|
||||
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
||||
B, _, _, S, C = audio_embeds.shape
|
||||
|
||||
# process audio of first frame
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
# process audio of latter frame
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
||||
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
||||
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
||||
|
||||
# first projection
|
||||
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
||||
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
||||
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
||||
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
||||
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
||||
|
||||
# second projection
|
||||
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
||||
|
||||
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
|
||||
|
||||
# normalization and reshape
|
||||
context_tokens = self.norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
|
||||
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
||||
|
||||
|
||||
class MultiTalkGetAttnMapPatch:
|
||||
def __init__(self, ref_target_masks=None):
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x = kwargs["x"]
|
||||
|
||||
if self.ref_target_masks is not None:
|
||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||
transformer_options["x_ref_attn_map"] = x_ref_attn_map
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkCrossAttnPatch:
|
||||
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
||||
self.model_patch = model_patch
|
||||
self.audio_scale = audio_scale
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
block_idx = transformer_options.get("block_index", None)
|
||||
x = kwargs["x"]
|
||||
if block_idx is None:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
audio_embeds = transformer_options.get("audio_embeds")
|
||||
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
|
||||
|
||||
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
||||
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
||||
norm_x, audio_embeds.to(x.dtype),
|
||||
shape=transformer_options["grid_sizes"],
|
||||
x_ref_attn_map=x_ref_attn_map
|
||||
)
|
||||
x = x + x_audio * self.audio_scale
|
||||
return x
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class MultiTalkApplyModelWrapper:
|
||||
def __init__(self, init_latents):
|
||||
self.init_latents = init_latents
|
||||
|
||||
def __call__(self, executor, x, *args, **kwargs):
|
||||
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
|
||||
samples = executor(x, *args, **kwargs)
|
||||
return samples
|
||||
|
||||
|
||||
class InfiniteTalkOuterSampleWrapper:
|
||||
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
|
||||
self.motion_frames_latent = motion_frames_latent
|
||||
self.model_patch = model_patch
|
||||
self.is_extend = is_extend
|
||||
|
||||
def __call__(self, executor, *args, **kwargs):
|
||||
model_patcher = executor.class_obj.model_patcher
|
||||
model_options = executor.class_obj.model_options
|
||||
process_latent_in = model_patcher.model.process_latent_in
|
||||
|
||||
# for InfiniteTalk, model input first latent(s) need to always be replaced on every step
|
||||
if self.motion_frames_latent is not None:
|
||||
wrappers = model_options["transformer_options"]["wrappers"]
|
||||
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
|
||||
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
|
||||
|
||||
# run the sampling process
|
||||
result = executor(*args, **kwargs)
|
||||
|
||||
# insert motion frames before decoding
|
||||
if self.is_extend:
|
||||
overlap = self.motion_frames_latent.shape[2]
|
||||
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
|
||||
|
||||
return result
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
if self.motion_frames_latent is not None:
|
||||
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
|
||||
return self
|
||||
@@ -49,7 +49,6 @@ import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -1148,27 +1147,6 @@ class CosmosPredict2(BaseModel):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
return latent_image / (1.0 - sigma)
|
||||
|
||||
class Anima(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
t5xxl_ids = kwargs.get("t5xxl_ids", None)
|
||||
t5xxl_weights = kwargs.get("t5xxl_weights", None)
|
||||
device = kwargs["device"]
|
||||
if cross_attn is not None:
|
||||
if t5xxl_ids is not None:
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
||||
if t5xxl_weights is not None:
|
||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
|
||||
if cross_attn.shape[1] < 512:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
|
||||
@@ -550,8 +550,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos_predict2"
|
||||
if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "anima"
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
|
||||
12
comfy/sd.py
12
comfy/sd.py
@@ -57,7 +57,6 @@ import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.newbie
|
||||
import comfy.text_encoders.anima
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -636,13 +635,14 @@ class VAE:
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2
|
||||
if self.latent_channels == 48: # Wan 2.2
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
|
||||
self.process_input = self.process_output = lambda image: image
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.process_output = lambda image: image
|
||||
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
|
||||
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
@@ -1048,7 +1048,6 @@ class TEModel(Enum):
|
||||
GEMMA_3_12B = 18
|
||||
JINA_CLIP_2 = 19
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1094,8 +1093,6 @@ def detect_te_model(sd):
|
||||
return TEModel.QWEN3_2B
|
||||
elif weight.shape[0] == 4096:
|
||||
return TEModel.QWEN3_8B
|
||||
elif weight.shape[0] == 1024:
|
||||
return TEModel.QWEN3_06B
|
||||
if weight.shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
@@ -1236,9 +1233,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
|
||||
@@ -23,7 +23,6 @@ import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.anima
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -993,36 +992,6 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||
|
||||
class Anima(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "anima",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Anima(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
@@ -1582,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -112,8 +112,7 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
|
||||
|
||||
class TAEHV(nn.Module):
|
||||
def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True),
|
||||
latent_format=None, show_progress_bar=False):
|
||||
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
|
||||
super().__init__()
|
||||
self.image_channels = 3
|
||||
self.patch_size = 1
|
||||
@@ -125,9 +124,6 @@ class TAEHV(nn.Module):
|
||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
self.patch_size = 2
|
||||
elif self.latent_channels == 128: # LTX2
|
||||
self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True)
|
||||
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
@@ -135,52 +131,41 @@ class TAEHV(nn.Module):
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
||||
TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
conv(64, self.latent_channels),
|
||||
)
|
||||
n_f = [256, 128, 64, 64]
|
||||
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
self.decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
||||
)
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
||||
self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow))
|
||||
self.frames_to_trim = self.t_upscale - 1
|
||||
self._show_progress_bar = show_progress_bar
|
||||
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if self.patch_size > 1:
|
||||
B, T, C, H, W = x.shape
|
||||
x = x.reshape(B * T, C, H, W)
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
|
||||
if x.shape[1] % self.t_downscale != 0:
|
||||
# pad at end to multiple of t_downscale
|
||||
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
n_pad = 4 - x.shape[1] % 4
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
|
||||
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1:
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from transformers import Qwen2Tokenizer, T5TokenizerFast
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
|
||||
|
||||
class AnimaTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.t5xxl.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
|
||||
class Qwen3_06BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class AnimaTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out = super().encode_token_weights(token_weight_pairs)
|
||||
out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int)
|
||||
out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0])))
|
||||
return out
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class AnimaTEModel_(AnimaTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return AnimaTEModel_
|
||||
@@ -118,18 +118,9 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
missing_all = []
|
||||
unexpected_all = []
|
||||
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||
if component_sd:
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False)
|
||||
missing_all.extend([f"{prefix}{k}" for k in missing])
|
||||
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
||||
|
||||
return (missing_all, unexpected_all)
|
||||
missing, unexpected = self.load_state_dict(sdo, strict=False)
|
||||
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
|
||||
return (missing, unexpected)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
constant = 6.0
|
||||
|
||||
@@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from . import _node_replace_public as node_replace
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
@@ -131,5 +130,4 @@ __all__ = [
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
"node_replace",
|
||||
]
|
||||
|
||||
@@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="MODEL_PATCH")
|
||||
class ModelPatch(ComfyTypeIO):
|
||||
class MODEL_PATCH(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER")
|
||||
@@ -2038,7 +2038,6 @@ __all__ = [
|
||||
"ControlNet",
|
||||
"Vae",
|
||||
"Model",
|
||||
"ModelPatch",
|
||||
"ClipVision",
|
||||
"ClipVisionOutput",
|
||||
"AudioEncoder",
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import app.node_replace_manager
|
||||
|
||||
def register_node_replacement(node_replace: NodeReplace):
|
||||
"""
|
||||
Register node replacement.
|
||||
"""
|
||||
app.node_replace_manager.register_node_replacement(node_replace)
|
||||
|
||||
|
||||
class NodeReplace:
|
||||
"""
|
||||
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
|
||||
|
||||
Also supports assigning specific values to the input widgets of the new node.
|
||||
"""
|
||||
def __init__(self,
|
||||
new_node_id: str,
|
||||
old_node_id: str,
|
||||
old_widget_ids: list[str] | None=None,
|
||||
input_mapping: list[InputMap] | None=None,
|
||||
output_mapping: list[OutputMap] | None=None,
|
||||
):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
"""
|
||||
Create serializable representation of the node replacement.
|
||||
"""
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": [m.as_dict() for m in self.input_mapping] if self.input_mapping else None,
|
||||
"output_mapping": [m.as_dict() for m in self.output_mapping] if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
class InputMap:
|
||||
"""
|
||||
Map inputs of node replacement.
|
||||
|
||||
Use InputMap.OldId or InputMap.SetValue for mapping purposes.
|
||||
"""
|
||||
class _Assign:
|
||||
def __init__(self, assign_type: str):
|
||||
self.assign_type = assign_type
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"assign_type": self.assign_type,
|
||||
}
|
||||
|
||||
class OldId(_Assign):
|
||||
"""
|
||||
Connect the input of the old node with given id to new node when replacing.
|
||||
"""
|
||||
def __init__(self, old_id: str):
|
||||
super().__init__("old_id")
|
||||
self.old_id = old_id
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | {
|
||||
"old_id": self.old_id,
|
||||
}
|
||||
|
||||
class SetValue(_Assign):
|
||||
"""
|
||||
Use the given value for the input of the new node when replacing; assumes input is a widget.
|
||||
"""
|
||||
def __init__(self, value: Any):
|
||||
super().__init__("set_value")
|
||||
self.value = value
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | {
|
||||
"value": self.value,
|
||||
}
|
||||
|
||||
def __init__(self, new_id: str, assign: OldId | SetValue):
|
||||
self.new_id = new_id
|
||||
self.assign = assign
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_id": self.new_id,
|
||||
"assign": self.assign.as_dict(),
|
||||
}
|
||||
|
||||
|
||||
class OutputMap:
|
||||
"""
|
||||
Map outputs of node replacement via indexes, as that's how outputs are stored.
|
||||
"""
|
||||
def __init__(self, new_idx: int, old_idx: int):
|
||||
self.new_idx = new_idx
|
||||
self.old_idx = old_idx
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_idx": self.new_idx,
|
||||
"old_idx": self.old_idx,
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
from ._node_replace import * # noqa: F403
|
||||
@@ -6,7 +6,7 @@ from comfy_api.latest import (
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension, node_replace #noqa: F401
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
@@ -46,5 +46,4 @@ __all__ = [
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
"node_replace",
|
||||
]
|
||||
|
||||
@@ -24,7 +24,7 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaImageEditNode",
|
||||
display_name="Bria FIBO Image Edit",
|
||||
display_name="Bria Image Edit",
|
||||
category="api node/image/Bria",
|
||||
description="Edit images using Bria latest model",
|
||||
inputs=[
|
||||
|
||||
@@ -364,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 1.5",
|
||||
display_name="OpenAI GPT Image 1",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@@ -429,7 +429,6 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-1", "gpt-image-1.5"],
|
||||
default="gpt-image-1.5",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
|
||||
@@ -28,6 +28,7 @@ class AlignYourStepsScheduler(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AlignYourStepsScheduler",
|
||||
search_aliases=["AYS scheduler"],
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
||||
|
||||
@@ -71,6 +71,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
search_aliases=["clip attention scale", "text encoder attention"],
|
||||
category="_for_testing/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@@ -10,6 +10,7 @@ class Canny(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Canny",
|
||||
search_aliases=["edge detection", "outline", "contour detection", "line art"],
|
||||
category="image/preprocessors",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
|
||||
@@ -38,6 +38,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ControlNetInpaintingAliMamaApply",
|
||||
search_aliases=["masked controlnet"],
|
||||
category="conditioning/controlnet",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
|
||||
@@ -297,6 +297,7 @@ class ExtendIntermediateSigmas(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ExtendIntermediateSigmas",
|
||||
search_aliases=["interpolate sigmas"],
|
||||
category="sampling/custom_sampling/sigmas",
|
||||
inputs=[
|
||||
io.Sigmas.Input("sigmas"),
|
||||
@@ -856,6 +857,7 @@ class DualCFGGuider(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DualCFGGuider",
|
||||
search_aliases=["dual prompt guidance"],
|
||||
category="sampling/custom_sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@@ -883,6 +885,7 @@ class DisableNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DisableNoise",
|
||||
search_aliases=["zero noise"],
|
||||
category="sampling/custom_sampling/noise",
|
||||
inputs=[],
|
||||
outputs=[io.Noise.Output()]
|
||||
@@ -1019,6 +1022,7 @@ class ManualSigmas(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
search_aliases=["custom noise schedule", "define sigmas"],
|
||||
category="_for_testing/custom_sampling",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
|
||||
@@ -1223,11 +1223,11 @@ class ResolutionBucket(io.ComfyNode):
|
||||
|
||||
class MakeTrainingDataset(io.ComfyNode):
|
||||
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
@@ -1309,11 +1309,11 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
|
||||
class SaveTrainingDataset(io.ComfyNode):
|
||||
"""Save encoded training dataset (latents + conditioning) to disk."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export training data"],
|
||||
display_name="Save Training Dataset",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
@@ -1410,11 +1410,11 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
|
||||
class LoadTrainingDataset(io.ComfyNode):
|
||||
"""Load encoded training dataset from disk."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
|
||||
@@ -11,6 +11,7 @@ class DifferentialDiffusion(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DifferentialDiffusion",
|
||||
search_aliases=["inpaint gradient", "variable denoise strength"],
|
||||
display_name="Differential Diffusion",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
|
||||
@@ -29,10 +29,8 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
do_easycache = easycache.should_do_easycache(sigmas)
|
||||
if do_easycache:
|
||||
easycache.check_metadata(x)
|
||||
# if there isn't a cache diff for current conds, we cannot skip this step
|
||||
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
|
||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||
if easycache.skip_current_step and can_apply_cache_diff:
|
||||
if easycache.skip_current_step:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
@@ -46,7 +44,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
||||
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
||||
easycache.cumulative_change_rate += approx_output_change_rate
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
|
||||
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
@@ -242,9 +240,6 @@ class EasyCacheHolder:
|
||||
return to_return.clone()
|
||||
return to_return
|
||||
|
||||
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
|
||||
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||
if self.first_cond_uuid in uuids:
|
||||
self.total_steps_skipped += 1
|
||||
|
||||
@@ -58,6 +58,7 @@ class FreSca(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FreSca",
|
||||
search_aliases=["frequency guidance"],
|
||||
display_name="FreSca",
|
||||
category="_for_testing",
|
||||
description="Applies frequency-dependent scaling to the guidance",
|
||||
|
||||
@@ -38,6 +38,7 @@ class CLIPTextEncodeHiDream(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeHiDream",
|
||||
search_aliases=["hidream prompt"],
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@@ -259,6 +259,7 @@ class SetClipHooks:
|
||||
return (clip,)
|
||||
|
||||
class ConditioningTimestepsRange:
|
||||
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
|
||||
NodeId = 'ConditioningTimestepsRange'
|
||||
NodeName = 'Timesteps Range'
|
||||
@classmethod
|
||||
@@ -468,6 +469,7 @@ class SetHookKeyframes:
|
||||
return (hooks,)
|
||||
|
||||
class CreateHookKeyframe:
|
||||
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
|
||||
NodeId = 'CreateHookKeyframe'
|
||||
NodeName = 'Create Hook Keyframe'
|
||||
@classmethod
|
||||
@@ -497,6 +499,7 @@ class CreateHookKeyframe:
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesInterpolated:
|
||||
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
|
||||
NodeId = 'CreateHookKeyframesInterpolated'
|
||||
NodeName = 'Create Hook Keyframes Interp.'
|
||||
@classmethod
|
||||
@@ -544,6 +547,7 @@ class CreateHookKeyframesInterpolated:
|
||||
return (prev_hook_kf,)
|
||||
|
||||
class CreateHookKeyframesFromFloats:
|
||||
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
|
||||
NodeId = 'CreateHookKeyframesFromFloats'
|
||||
NodeName = 'Create Hook Keyframes From Floats'
|
||||
@classmethod
|
||||
@@ -618,6 +622,7 @@ class SetModelHooksOnCond:
|
||||
# Combine Hooks
|
||||
#------------------------------------------
|
||||
class CombineHooks:
|
||||
SEARCH_ALIASES = ["merge hooks"]
|
||||
NodeId = 'CombineHooks2'
|
||||
NodeName = 'Combine Hooks [2]'
|
||||
@classmethod
|
||||
|
||||
@@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
|
||||
@@ -104,6 +104,7 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeKandinsky5",
|
||||
search_aliases=["kandinsky prompt"],
|
||||
category="advanced/conditioning/kandinsky5",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@@ -75,6 +75,7 @@ class Preview3D(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Preview3D",
|
||||
search_aliases=["view mesh", "3d viewer"],
|
||||
display_name="Preview 3D & Animation",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
|
||||
@@ -224,6 +224,7 @@ class ConvertStringToComboNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertStringToComboNode",
|
||||
search_aliases=["string to dropdown", "text to combo"],
|
||||
display_name="Convert String to Combo",
|
||||
category="logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
@@ -239,6 +240,7 @@ class InvertBooleanNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="InvertBooleanNode",
|
||||
search_aliases=["not", "toggle", "negate", "flip boolean"],
|
||||
display_name="Invert Boolean",
|
||||
category="logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
|
||||
@@ -78,6 +78,7 @@ class LoraSave(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoraSave",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Extract and Save Lora",
|
||||
category="_for_testing",
|
||||
inputs=[
|
||||
|
||||
@@ -79,6 +79,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeLumina2",
|
||||
search_aliases=["lumina prompt"],
|
||||
display_name="CLIP Text Encode for Lumina2",
|
||||
category="conditioning",
|
||||
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
|
||||
|
||||
@@ -299,6 +299,7 @@ class RescaleCFG:
|
||||
return (m, )
|
||||
|
||||
class ModelComputeDtype:
|
||||
SEARCH_ALIASES = ["model precision", "change dtype"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
|
||||
@@ -91,6 +91,7 @@ class CLIPMergeSimple:
|
||||
|
||||
|
||||
class CLIPSubtract:
|
||||
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip1": ("CLIP",),
|
||||
@@ -113,6 +114,7 @@ class CLIPSubtract:
|
||||
|
||||
|
||||
class CLIPAdd:
|
||||
SEARCH_ALIASES = ["combine clip"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip1": ("CLIP",),
|
||||
@@ -225,6 +227,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
|
||||
|
||||
class CheckpointSave:
|
||||
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@@ -337,6 +340,7 @@ class VAESave:
|
||||
return {}
|
||||
|
||||
class ModelSave:
|
||||
SEARCH_ALIASES = ["export model", "checkpoint save"]
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
import comfy.ldm.lumina.controlnet
|
||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
@@ -258,14 +257,6 @@ class ModelPatchLoader:
|
||||
if torch.count_nonzero(ref_weight) == 0:
|
||||
config['broken'] = True
|
||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
||||
elif "audio_proj.proj1.weight" in sd:
|
||||
model = MultiTalkModelPatch(
|
||||
audio_window=5, context_tokens=32, vae_scale=4,
|
||||
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
|
||||
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
|
||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
operations=comfy.ops.manual_cast)
|
||||
|
||||
model.load_state_dict(sd)
|
||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
@@ -533,38 +524,6 @@ class USOStyleReference:
|
||||
return (model_patched,)
|
||||
|
||||
|
||||
class MultiTalkModelPatch(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
audio_window: int = 5,
|
||||
intermediate_dim: int = 512,
|
||||
in_dim: int = 5120,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
vae_scale: int = 4,
|
||||
num_layers: int = 40,
|
||||
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.audio_proj = MultiTalkAudioProjModel(
|
||||
seq_len=audio_window,
|
||||
seq_len_vf=audio_window+vae_scale-1,
|
||||
intermediate_dim=intermediate_dim,
|
||||
out_dim=out_dim,
|
||||
context_tokens=context_tokens,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
|
||||
@@ -7,6 +7,7 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodePixArtAlpha",
|
||||
search_aliases=["pixart prompt"],
|
||||
category="advanced/conditioning",
|
||||
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
|
||||
inputs=[
|
||||
|
||||
@@ -637,97 +637,6 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
batched = batch_masks(values)
|
||||
return io.NodeOutput(batched)
|
||||
|
||||
|
||||
from comfy_api.latest import node_replace
|
||||
|
||||
def register_replacements():
|
||||
register_replacements_longeredge()
|
||||
register_replacements_batchimages()
|
||||
register_replacements_upscaleimage()
|
||||
register_replacements_controlnet()
|
||||
register_replacements_load3d()
|
||||
register_replacements_preview3d()
|
||||
register_replacements_svdimg2vid()
|
||||
register_replacements_conditioningavg()
|
||||
|
||||
def register_replacements_longeredge():
|
||||
# No dynamic inputs here
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ImageScaleToMaxDimension",
|
||||
old_node_id="ResizeImagesByLongerEdge",
|
||||
old_widget_ids=["longer_edge"],
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="image", assign=node_replace.InputMap.OldId("images")),
|
||||
node_replace.InputMap(new_id="largest_size", assign=node_replace.InputMap.OldId("longer_edge")),
|
||||
node_replace.InputMap(new_id="upscale_method", assign=node_replace.InputMap.SetValue("lanczos")),
|
||||
],
|
||||
# just to test the frontend output_mapping code, does nothing really here
|
||||
output_mapping=[node_replace.OutputMap(new_idx=0, old_idx=0)],
|
||||
))
|
||||
|
||||
def register_replacements_batchimages():
|
||||
# BatchImages node uses Autogrow
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="BatchImagesNode",
|
||||
old_node_id="ImageBatch",
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="images.image0", assign=node_replace.InputMap.OldId("image1")),
|
||||
node_replace.InputMap(new_id="images.image1", assign=node_replace.InputMap.OldId("image2")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_upscaleimage():
|
||||
# ResizeImageMaskNode uses DynamicCombo
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ResizeImageMaskNode",
|
||||
old_node_id="ImageScaleBy",
|
||||
old_widget_ids=["upscale_method", "scale_by"],
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="input", assign=node_replace.InputMap.OldId("image")),
|
||||
node_replace.InputMap(new_id="resize_type", assign=node_replace.InputMap.SetValue("scale by multiplier")),
|
||||
node_replace.InputMap(new_id="resize_type.multiplier", assign=node_replace.InputMap.OldId("scale_by")),
|
||||
node_replace.InputMap(new_id="scale_method", assign=node_replace.InputMap.OldId("upscale_method")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_controlnet():
|
||||
# T2IAdapterLoader → ControlNetLoader
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ControlNetLoader",
|
||||
old_node_id="T2IAdapterLoader",
|
||||
input_mapping=[
|
||||
node_replace.InputMap(new_id="control_net_name", assign=node_replace.InputMap.OldId("t2i_adapter_name")),
|
||||
],
|
||||
))
|
||||
|
||||
def register_replacements_load3d():
|
||||
# Load3DAnimation merged into Load3D
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="Load3D",
|
||||
old_node_id="Load3DAnimation",
|
||||
))
|
||||
|
||||
def register_replacements_preview3d():
|
||||
# Preview3DAnimation merged into Preview3D
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="Preview3D",
|
||||
old_node_id="Preview3DAnimation",
|
||||
))
|
||||
|
||||
def register_replacements_svdimg2vid():
|
||||
# Typo fix: SDV → SVD
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="SVD_img2vid_Conditioning",
|
||||
old_node_id="SDV_img2vid_Conditioning",
|
||||
))
|
||||
|
||||
def register_replacements_conditioningavg():
|
||||
# Typo fix: trailing space in node name
|
||||
node_replace.register_node_replacement(node_replace.NodeReplace(
|
||||
new_node_id="ConditioningAverage",
|
||||
old_node_id="ConditioningAverage ",
|
||||
))
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
|
||||
@@ -16,7 +16,7 @@ class PreviewAny():
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "utils"
|
||||
SEARCH_ALIASES = ["preview", "show", "display", "view", "show text", "display text", "preview text", "show output", "inspect", "debug"]
|
||||
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
|
||||
|
||||
def main(self, source=None):
|
||||
value = 'None'
|
||||
|
||||
@@ -65,6 +65,7 @@ class CLIPTextEncodeSD3(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeSD3",
|
||||
search_aliases=["sd3 prompt"],
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
|
||||
@@ -1101,6 +1101,7 @@ class SaveLoRA(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveLoRA",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Save LoRA Weights",
|
||||
category="loaders",
|
||||
is_experimental=True,
|
||||
@@ -1144,6 +1145,7 @@ class LossGraphNode(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LossGraphNode",
|
||||
search_aliases=["training chart", "training visualization", "plot loss"],
|
||||
display_name="Plot Loss Graph",
|
||||
category="training",
|
||||
is_experimental=True,
|
||||
|
||||
@@ -8,10 +8,9 @@ import comfy.latent_formats
|
||||
import comfy.clip_vision
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Tuple, TypedDict
|
||||
from typing import Tuple
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import logging
|
||||
|
||||
class WanImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -1289,171 +1288,6 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
|
||||
class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
class DCValues(TypedDict):
|
||||
mode: str
|
||||
audio_encoder_output_2: io.AudioEncoderOutput.Type
|
||||
mask: io.Mask.Type
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanInfiniteTalkToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.DynamicCombo.Input("mode", options=[
|
||||
io.DynamicCombo.Option("single_speaker", []),
|
||||
io.DynamicCombo.Option("two_speakers", [
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
|
||||
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
|
||||
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
|
||||
]),
|
||||
]),
|
||||
io.Model.Input("model"),
|
||||
io.ModelPatch.Input("model_patch"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
|
||||
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
|
||||
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
io.Image.Input("previous_frames", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="model"),
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
io.Int.Output(display_name="trim_image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
|
||||
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
|
||||
|
||||
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
|
||||
raise ValueError("Not enough previous frames provided.")
|
||||
|
||||
if mode["mode"] == "two_speakers":
|
||||
audio_encoder_output_2 = mode["audio_encoder_output_2"]
|
||||
mask_1 = mode["mask_1"]
|
||||
mask_2 = mode["mask_2"]
|
||||
|
||||
if audio_encoder_output_2 is not None:
|
||||
if mask_1 is None or mask_2 is None:
|
||||
raise ValueError("Masks must be provided if two audio encoder outputs are used.")
|
||||
|
||||
ref_masks = None
|
||||
if mask_1 is not None and mask_2 is not None:
|
||||
if audio_encoder_output_2 is None:
|
||||
raise ValueError("Second audio encoder output must be provided if two masks are used.")
|
||||
ref_masks = torch.cat([mask_1, mask_2])
|
||||
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
|
||||
image[:start_image.shape[0]] = start_image
|
||||
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
model_patched = model.clone()
|
||||
|
||||
encoded_audio_list = []
|
||||
seq_lengths = []
|
||||
|
||||
for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
|
||||
if audio_encoder_output is None:
|
||||
continue
|
||||
all_layers = audio_encoder_output["encoded_audio_all_layers"]
|
||||
encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
|
||||
encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
|
||||
encoded_audio_list.append(encoded_audio)
|
||||
seq_lengths.append(encoded_audio.shape[0])
|
||||
|
||||
# Pad / combine depending on multi_audio_type
|
||||
multi_audio_type = "add"
|
||||
if len(encoded_audio_list) > 1:
|
||||
if multi_audio_type == "para":
|
||||
max_len = max(seq_lengths)
|
||||
padded = []
|
||||
for emb in encoded_audio_list:
|
||||
if emb.shape[0] < max_len:
|
||||
pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
|
||||
emb = torch.cat([emb, pad], dim=0)
|
||||
padded.append(emb)
|
||||
encoded_audio_list = padded
|
||||
elif multi_audio_type == "add":
|
||||
total_len = sum(seq_lengths)
|
||||
full_list = []
|
||||
offset = 0
|
||||
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
|
||||
full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
|
||||
full[offset:offset+seq_len] = emb
|
||||
full_list.append(full)
|
||||
offset += seq_len
|
||||
encoded_audio_list = full_list
|
||||
|
||||
token_ref_target_masks = None
|
||||
if ref_masks is not None:
|
||||
token_ref_target_masks = torch.nn.functional.interpolate(
|
||||
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
|
||||
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
|
||||
|
||||
# when extending from previous frames
|
||||
if previous_frames is not None:
|
||||
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
frame_offset = previous_frames.shape[0] - motion_frame_count
|
||||
|
||||
audio_start = frame_offset
|
||||
audio_end = audio_start + length
|
||||
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
|
||||
|
||||
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
|
||||
trim_image = motion_frame_count
|
||||
else:
|
||||
audio_start = trim_image = 0
|
||||
audio_end = length
|
||||
motion_frames_latent = concat_latent_image[:, :, :1]
|
||||
|
||||
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
|
||||
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
|
||||
|
||||
# add outer sample wrapper
|
||||
model_patched.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
|
||||
"infinite_talk_outer_sample",
|
||||
InfiniteTalkOuterSampleWrapper(
|
||||
motion_frames_latent,
|
||||
model_patch,
|
||||
is_extend=previous_frames is not None,
|
||||
))
|
||||
# add cross-attention patch
|
||||
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
|
||||
if token_ref_target_masks is not None:
|
||||
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -1473,7 +1307,6 @@ class WanExtension(ComfyExtension):
|
||||
WanHuMoImageToVideo,
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
@@ -324,6 +324,7 @@ class GenerateTracks(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GenerateTracks",
|
||||
search_aliases=["motion paths", "camera movement", "trajectory"],
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=4096, step=16),
|
||||
|
||||
@@ -5,6 +5,7 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
|
||||
class WebcamCapture(nodes.LoadImage):
|
||||
SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
default_preview_method = args.preview_method
|
||||
|
||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
||||
|
||||
def preview_to_image(latent_image, do_scale=True):
|
||||
if do_scale:
|
||||
|
||||
2
nodes.py
2
nodes.py
@@ -707,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader):
|
||||
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
||||
|
||||
class VAELoader:
|
||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
||||
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
|
||||
@staticmethod
|
||||
def vae_list(s):
|
||||
|
||||
@@ -40,7 +40,6 @@ from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@@ -205,7 +204,6 @@ class PromptServer():
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.node_replace_manager = NodeReplaceManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
@@ -994,7 +992,6 @@ class PromptServer():
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.node_replace_manager.add_routes(self.routes)
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
|
||||
Reference in New Issue
Block a user