@@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module):
                  num_heads,
                  window_size=(-1, -1),
                  qk_norm=True,
-                 eps=1e-6, operation_settings={}):
+                 eps=1e-6,
+                 kv_dim=None,
+                 operation_settings={}):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -43,11 +45,13 @@ class WanSelfAttention(nn.Module):
         self.window_size = window_size
         self.qk_norm = qk_norm
         self.eps = eps
+        if kv_dim is None:
+            kv_dim = dim
 
         # layers
         self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
         self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
@@ -402,6 +406,7 @@ class WanModel(torch.nn.Module):
                  eps=1e-6,
                  flf_pos_embed_token_number=None,
                  in_dim_ref_conv=None,
+                 wan_attn_block_class=WanAttentionBlock,
                  image_model=None,
                  device=None,
                  dtype=None,
@@ -479,8 +484,8 @@ class WanModel(torch.nn.Module):
         # blocks
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
-            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+            wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
+                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
             for _ in range(num_layers)
         ])
 
@@ -1325,3 +1330,247 @@ class WanModel_S2V(WanModel):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x
+
+
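+# WanT2VCrossAttentionGather (below) cross-attends video tokens to per-frame audio tokens:
+# keys/values are regrouped into 16 audio tokens per latent frame and the video queries are
+# split into the same number of frame groups, so each frame's video tokens attend only to
+# that frame's 16 audio tokens (this assumes the flattened video sequence is frame-major
+# and divides evenly by the number of audio frames).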
+class WanT2VCrossAttentionGather(WanSelfAttention):
+
+    def forward(self, x, context, transformer_options={}, **kwargs):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L1, C] - video tokens
+            context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
+        """
+        b, n, d = x.size(0), self.num_heads, self.head_dim
+
+        q = self.norm_q(self.q(x))
+        k = self.norm_k(self.k(context))
+        v = self.v(context)
+
+        # Handle audio temporal structure (16 tokens per frame)
+        k = k.reshape(-1, 16, n, d).transpose(1, 2)
+        v = v.reshape(-1, 16, n, d).transpose(1, 2)
+
+        # Handle video spatial structure
+        q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
+
+        x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
+
+        x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
+        x = self.o(x)
+        return x
+
+
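+# AudioCrossAttentionWrapper (below) is a residual adapter: it LayerNorms the video tokens,
+# runs the gathered audio cross-attention against the 1536-dim audio context, and adds the
+# result back onto x.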
+class AudioCrossAttentionWrapper(nn.Module):
+    def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
+        super().__init__()
+
+        self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
+        self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x, audio, transformer_options={}):
+        x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
+        return x
+
+
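+# WanAttentionBlockAudio (below) extends WanAttentionBlock with the audio adapter: the block
+# runs the usual self-attention and text/image cross-attention, then, when `audio` is provided,
+# applies audio_cross_attn_wrapper before the FFN; with audio=None it matches the base block.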
+class WanAttentionBlockAudio(WanAttentionBlock):
+
+    def __init__(self,
+                 cross_attn_type,
+                 dim,
+                 ffn_dim,
+                 num_heads,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=False,
+                 eps=1e-6, operation_settings={}):
+        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
+        self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
+
+    def forward(
+        self,
+        x,
+        e,
+        freqs,
+        context,
+        context_img_len=257,
+        audio=None,
+        transformer_options={},
+    ):
+        r"""
+        Args:
+            x(Tensor): Shape [B, L, C]
+            e(Tensor): Shape [B, 6, C]
+            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+        """
+        # assert e.dtype == torch.float32
+
+        if e.ndim < 4:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+        # assert e[0].dtype == torch.float32
+
+        # self-attention
+        y = self.self_attn(
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+            freqs, transformer_options=transformer_options)
+
+        x = torch.addcmul(x, y, repeat_e(e[2], x))
+
+        # cross-attention & ffn
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+        if audio is not None:
+            x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))
+        return x
+
+
+class DummyAdapterLayer(nn.Module):
+    def __init__(self, layer):
+        super().__init__()
+        self.layer = layer
+
+    def forward(self, *args, **kwargs):
+        return self.layer(*args, **kwargs)
+
+
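+# AudioProjModel (below) flattens a per-frame window of audio features (window w, blocks b,
+# channels c) into a single vector, projects it through two ReLU layers, and emits
+# `context_tokens` tokens of width `output_dim` per frame, followed by a LayerNorm;
+# the output has shape [bz, f, context_tokens, output_dim].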
+class AudioProjModel(nn.Module):
+    def __init__(
+        self,
+        seq_len=5,
+        blocks=13,  # add a new parameter blocks
+        channels=768,  # add a new parameter channels
+        intermediate_dim=512,
+        output_dim=1536,
+        context_tokens=16,
+        device=None,
+        dtype=None,
+        operations=None,
+    ):
+        super().__init__()
+
+        self.seq_len = seq_len
+        self.blocks = blocks
+        self.channels = channels
+        self.input_dim = seq_len * blocks * channels  # input_dim is the product of seq_len, blocks and channels
+        self.intermediate_dim = intermediate_dim
+        self.context_tokens = context_tokens
+        self.output_dim = output_dim
+
+        # define multiple linear layers
+        self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
+        self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
+        self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
+
+        self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
+
+    def forward(self, audio_embeds):
+        video_length = audio_embeds.shape[1]
+        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
+        batch_size, window_size, blocks, channels = audio_embeds.shape
+        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
+
+        audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
+        audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
+
+        context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
+
+        context_tokens = self.audio_proj_glob_norm(context_tokens)
+        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
+
+        return context_tokens
+
+
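+# HumoWanModel (below) is a WanModel variant that builds its transformer blocks with
+# WanAttentionBlockAudio (via the wan_attn_block_class hook added above) and owns an
+# AudioProjModel; forward_orig optionally appends reference latents (with a RoPE time
+# offset of `time`) and passes the projected audio tokens into every block.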
+class HumoWanModel(WanModel):
+    r"""
+    Wan diffusion backbone supporting both text-to-video and image-to-video.
+    """
+
+    def __init__(self,
+                 model_type='humo',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 flf_pos_embed_token_number=None,
+                 image_model=None,
+                 audio_token_num=16,
+                 device=None,
+                 dtype=None,
+                 operations=None,
+                 ):
+
+        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+
+        self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
+
+    def forward_orig(
+        self,
+        x,
+        t,
+        context,
+        freqs=None,
+        audio_embed=None,
+        reference_latent=None,
+        transformer_options={},
+        **kwargs,
+    ):
+        bs, _, time, height, width = x.shape
+
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # time embeddings
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+        e = e.reshape(t.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+        if reference_latent is not None:
+            ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
+            ref = ref.flatten(2).transpose(1, 2)
+            freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
+            x = torch.cat([x, ref], dim=1)
+            freqs = torch.cat([freqs, freqs_ref], dim=1)
+            del ref, freqs_ref
+
+        # context
+        context = self.text_embedding(context)
+        context_img_len = None
+
+        if audio_embed is not None:
+            audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
+        else:
+            audio = None
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
+
+        # head
+        x = self.head(x, e)
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
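+# Shape sketch for the audio path (illustrative, based on the configuration above):
+# audio_embed is [bz, f, w, b, c]; with seq_len=8, blocks=5, channels=1280 each frame's
+# window flattens to 8*5*1280 = 51200 features and is projected to 16 tokens of width 1536,
+# so after the permute/flatten/transpose `audio` is [bz, f*16, 1536], matching the
+# [B, frames*16, 1536] context expected by WanT2VCrossAttentionGather.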