mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-04 20:59:58 +00:00
Merge branch 'master' into worksplit-multigpu
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from .wav2vec2 import Wav2Vec2Model
|
||||
from .whisper import WhisperLargeV3
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
@@ -11,7 +12,18 @@ class AudioEncoderModel():
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
|
||||
model_type = config.pop("model_type")
|
||||
model_config = dict(config)
|
||||
model_config.update({
|
||||
"dtype": self.dtype,
|
||||
"device": offload_device,
|
||||
"operations": comfy.ops.manual_cast
|
||||
})
|
||||
|
||||
if model_type == "wav2vec2":
|
||||
self.model = Wav2Vec2Model(**model_config)
|
||||
elif model_type == "whisper3":
|
||||
self.model = WhisperLargeV3(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
@@ -29,14 +41,51 @@ class AudioEncoderModel():
|
||||
outputs = {}
|
||||
outputs["encoded_audio"] = out
|
||||
outputs["encoded_audio_all_layers"] = all_layers
|
||||
outputs["audio_samples"] = audio.shape[2]
|
||||
return outputs
|
||||
|
||||
|
||||
def load_audio_encoder_from_sd(sd, prefix=""):
|
||||
audio_encoder = AudioEncoderModel(None)
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
|
||||
if "encoder.layer_norm.bias" in sd: #wav2vec2
|
||||
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
|
||||
if embed_dim == 1024:# large
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 1024,
|
||||
"num_heads": 16,
|
||||
"num_layers": 24,
|
||||
"conv_norm": True,
|
||||
"conv_bias": True,
|
||||
"do_normalize": True,
|
||||
"do_stable_layer_norm": True
|
||||
}
|
||||
elif embed_dim == 768: # base
|
||||
config = {
|
||||
"model_type": "wav2vec2",
|
||||
"embed_dim": 768,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"conv_norm": False,
|
||||
"conv_bias": False,
|
||||
"do_normalize": False, # chinese-wav2vec2-base has this False
|
||||
"do_stable_layer_norm": False
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
|
||||
elif "model.encoder.embed_positions.weight" in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
|
||||
config = {
|
||||
"model_type": "whisper3",
|
||||
}
|
||||
else:
|
||||
raise RuntimeError("ERROR: audio encoder not supported.")
|
||||
|
||||
audio_encoder = AudioEncoderModel(config)
|
||||
m, u = audio_encoder.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing audio encoder: {}".format(m))
|
||||
if len(u) > 0:
|
||||
logging.warning("unexpected audio encoder: {}".format(u))
|
||||
|
||||
return audio_encoder
|
||||
|
||||
@@ -13,19 +13,49 @@ class LayerNormConv(nn.Module):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
|
||||
|
||||
class LayerGroupNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(self.layer_norm(x))
|
||||
|
||||
class ConvNoNorm(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return torch.nn.functional.gelu(x)
|
||||
|
||||
|
||||
class ConvFeatureEncoder(nn.Module):
|
||||
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
|
||||
def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
if conv_norm:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
else:
|
||||
self.conv_layers = nn.ModuleList([
|
||||
LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unsqueeze(1)
|
||||
@@ -76,6 +106,7 @@ class TransformerEncoder(nn.Module):
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
@@ -86,20 +117,25 @@ class TransformerEncoder(nn.Module):
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.pos_conv_embed(x)
|
||||
all_x = ()
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x, mask)
|
||||
x = self.layer_norm(x)
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
@@ -145,6 +181,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
embed_dim=768,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
@@ -154,15 +191,19 @@ class TransformerEncoderLayer(nn.Module):
|
||||
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
x = self.layer_norm(x)
|
||||
if self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
x = self.attention(x, mask=mask)
|
||||
x = residual + x
|
||||
|
||||
x = x + self.feed_forward(self.final_layer_norm(x))
|
||||
return x
|
||||
if not self.do_stable_layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
return self.final_layer_norm(x + self.feed_forward(x))
|
||||
else:
|
||||
return x + self.feed_forward(self.final_layer_norm(x))
|
||||
|
||||
|
||||
class Wav2Vec2Model(nn.Module):
|
||||
@@ -174,34 +215,38 @@ class Wav2Vec2Model(nn.Module):
|
||||
final_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
conv_norm=True,
|
||||
conv_bias=True,
|
||||
do_normalize=True,
|
||||
do_stable_layer_norm=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
conv_dim = 512
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
self.encoder = TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
do_stable_layer_norm=do_stable_layer_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, x, mask_time_indices=None, return_dict=False):
|
||||
|
||||
x = torch.mean(x, dim=1)
|
||||
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
if self.do_normalize:
|
||||
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
|
||||
|
||||
features = self.feature_extractor(x)
|
||||
features = self.feature_projection(features)
|
||||
|
||||
batch_size, seq_len, _ = features.shape
|
||||
|
||||
x, all_x = self.encoder(features)
|
||||
|
||||
return x, all_x
|
||||
|
||||
186
comfy/audio_encoders/whisper.py
Executable file
186
comfy/audio_encoders/whisper.py
Executable file
@@ -0,0 +1,186 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from typing import Optional
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.ops
|
||||
|
||||
class WhisperFeatureExtractor(nn.Module):
|
||||
def __init__(self, n_mels=128, device=None):
|
||||
super().__init__()
|
||||
self.sample_rate = 16000
|
||||
self.n_fft = 400
|
||||
self.hop_length = 160
|
||||
self.n_mels = n_mels
|
||||
self.chunk_length = 30
|
||||
self.n_samples = 480000
|
||||
|
||||
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
norm="slaney",
|
||||
mel_scale="slaney",
|
||||
).to(device)
|
||||
|
||||
def __call__(self, audio):
|
||||
audio = torch.mean(audio, dim=1)
|
||||
batch_size = audio.shape[0]
|
||||
processed_audio = []
|
||||
|
||||
for i in range(batch_size):
|
||||
aud = audio[i]
|
||||
if aud.shape[0] > self.n_samples:
|
||||
aud = aud[:self.n_samples]
|
||||
elif aud.shape[0] < self.n_samples:
|
||||
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
|
||||
processed_audio.append(aud)
|
||||
|
||||
audio = torch.stack(processed_audio)
|
||||
|
||||
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
|
||||
|
||||
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
|
||||
log_mel_spec = (log_mel_spec + 4.0) / 4.0
|
||||
|
||||
return log_mel_spec
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert d_model % n_heads == 0
|
||||
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.d_k = d_model // n_heads
|
||||
|
||||
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
|
||||
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
|
||||
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
|
||||
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
|
||||
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
|
||||
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x = self.self_attn(x, x, x, attention_mask)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.fc1(x)
|
||||
x = F.gelu(x)
|
||||
x = self.fc2(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_ctx: int = 1500,
|
||||
n_state: int = 1280,
|
||||
n_head: int = 20,
|
||||
n_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
|
||||
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
|
||||
|
||||
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(n_layer)
|
||||
])
|
||||
|
||||
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
|
||||
|
||||
all_x = ()
|
||||
for layer in self.layers:
|
||||
all_x += (x,)
|
||||
x = layer(x)
|
||||
|
||||
x = self.layer_norm(x)
|
||||
all_x += (x,)
|
||||
return x, all_x
|
||||
|
||||
|
||||
class WhisperLargeV3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int = 128,
|
||||
n_audio_ctx: int = 1500,
|
||||
n_audio_state: int = 1280,
|
||||
n_audio_head: int = 20,
|
||||
n_audio_layer: int = 32,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
|
||||
|
||||
self.encoder = AudioEncoder(
|
||||
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward(self, audio):
|
||||
mel = self.feature_extractor(audio)
|
||||
x, all_x = self.encoder(mel)
|
||||
return x, all_x
|
||||
@@ -86,24 +86,24 @@ class BatchedBrownianTree:
|
||||
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
||||
|
||||
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
||||
self.cpu_tree = True
|
||||
if "cpu" in kwargs:
|
||||
self.cpu_tree = kwargs.pop("cpu")
|
||||
self.cpu_tree = kwargs.pop("cpu", True)
|
||||
t0, t1, self.sign = self.sort(t0, t1)
|
||||
w0 = kwargs.get('w0', torch.zeros_like(x))
|
||||
w0 = kwargs.pop('w0', None)
|
||||
if w0 is None:
|
||||
w0 = torch.zeros_like(x)
|
||||
self.batched = False
|
||||
if seed is None:
|
||||
seed = torch.randint(0, 2 ** 63 - 1, []).item()
|
||||
self.batched = True
|
||||
try:
|
||||
assert len(seed) == x.shape[0]
|
||||
seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
|
||||
elif isinstance(seed, (tuple, list)):
|
||||
if len(seed) != x.shape[0]:
|
||||
raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
|
||||
self.batched = True
|
||||
w0 = w0[0]
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
if self.cpu_tree:
|
||||
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
|
||||
else:
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
seed = (seed,)
|
||||
if self.cpu_tree:
|
||||
t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
|
||||
self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
@@ -111,11 +111,10 @@ class BatchedBrownianTree:
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
device, dtype = t0.device, t0.dtype
|
||||
if self.cpu_tree:
|
||||
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
|
||||
else:
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
|
||||
t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
|
||||
@@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat):
|
||||
|
||||
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
|
||||
|
||||
class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
@@ -624,3 +629,20 @@ class Hunyuan3Dv2mini(LatentFormat):
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 1.0, 0.0, 0.0 ],
|
||||
[ 0.0, 1.0, 0.0 ],
|
||||
[ 0.0, 0.0, 1.0 ]
|
||||
]
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
@@ -133,6 +133,7 @@ class Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
@@ -140,6 +141,7 @@ class Attention(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
|
||||
@@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
transformer_options={},
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
@@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
@@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -298,7 +298,8 @@ class Attention(nn.Module):
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None,
|
||||
causal = None
|
||||
causal = None,
|
||||
transformer_options={},
|
||||
):
|
||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||
|
||||
@@ -363,7 +364,7 @@ class Attention(nn.Module):
|
||||
heads_per_kv_head = h // kv_h
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = self.to_out(out)
|
||||
|
||||
if mask is not None:
|
||||
@@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
|
||||
global_cond=None,
|
||||
mask = None,
|
||||
context_mask = None,
|
||||
rotary_pos_emb = None
|
||||
rotary_pos_emb = None,
|
||||
transformer_options={}
|
||||
):
|
||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||
|
||||
@@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
|
||||
residual = x
|
||||
x = self.pre_norm(x)
|
||||
x = x * (1 + scale_self) + shift_self
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
x = x * torch.sigmoid(1 - gate_self)
|
||||
x = x + residual
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
|
||||
x = x + residual
|
||||
|
||||
else:
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
|
||||
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
|
||||
|
||||
if context is not None:
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
|
||||
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer(x)
|
||||
@@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
|
||||
return_info = False,
|
||||
**kwargs
|
||||
):
|
||||
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
context = kwargs["context"]
|
||||
|
||||
@@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
|
||||
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
|
||||
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
|
||||
|
||||
if return_info:
|
||||
|
||||
@@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
|
||||
)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c):
|
||||
def forward(self, c, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
|
||||
@@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
|
||||
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
|
||||
q, k = self.q_norm1(q), self.k_norm1(k)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
c = self.w1o(output)
|
||||
return c
|
||||
|
||||
@@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):
|
||||
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x):
|
||||
def forward(self, c, x, transformer_options={}):
|
||||
|
||||
bsz, seqlen1, _ = c.shape
|
||||
bsz, seqlen2, _ = x.shape
|
||||
@@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
|
||||
torch.cat([cv, xv], dim=1),
|
||||
)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
c, x = output.split([seqlen1, seqlen2], dim=1)
|
||||
c = self.w1o(c)
|
||||
@@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
|
||||
self.is_last = is_last
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, c, x, global_cond, **kwargs):
|
||||
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
|
||||
|
||||
cres, xres = c, x
|
||||
|
||||
@@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
|
||||
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
|
||||
|
||||
# attention
|
||||
c, x = self.attn(c, x)
|
||||
c, x = self.attn(c, x, transformer_options=transformer_options)
|
||||
|
||||
|
||||
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
|
||||
@@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
|
||||
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
#@torch.compile()
|
||||
def forward(self, cx, global_cond, **kwargs):
|
||||
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
|
||||
cxres = cx
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
|
||||
global_cond
|
||||
).chunk(6, dim=1)
|
||||
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
|
||||
cx = self.attn(cx)
|
||||
cx = self.attn(cx, transformer_options=transformer_options)
|
||||
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
|
||||
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
|
||||
cx = gate_mlp.unsqueeze(1) * mlpout
|
||||
@@ -473,13 +473,14 @@ class MMDiT(nn.Module):
|
||||
out = {}
|
||||
out["txt"], out["img"] = layer(args["txt"],
|
||||
args["img"],
|
||||
args["vec"])
|
||||
args["vec"],
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
c = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
c, x = layer(c, x, global_cond, **kwargs)
|
||||
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
if len(self.single_layers) > 0:
|
||||
c_len = c.size(1)
|
||||
@@ -488,13 +489,13 @@ class MMDiT(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = layer(args["img"], args["vec"])
|
||||
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
cx = out["img"]
|
||||
else:
|
||||
cx = layer(cx, global_cond, **kwargs)
|
||||
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
x = cx[:, c_len:]
|
||||
|
||||
|
||||
@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
|
||||
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
def forward(self, q, k, v, transformer_options={}):
|
||||
q = self.to_q(q)
|
||||
k = self.to_k(k)
|
||||
v = self.to_v(v)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
|
||||
|
||||
return self.out_proj(out)
|
||||
|
||||
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
|
||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, kv, self_attn=False):
|
||||
def forward(self, x, kv, self_attn=False, transformer_options={}):
|
||||
orig_shape = x.shape
|
||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||
if self_attn:
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv)
|
||||
x = self.attn(x, kv, kv, transformer_options=transformer_options)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
|
||||
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
|
||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, kv):
|
||||
def forward(self, x, kv, transformer_options={}):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ class StageB(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
def _down_encode(self, x, r_embed, clip, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@@ -187,7 +187,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -199,7 +199,7 @@ class StageB(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@@ -216,7 +216,7 @@ class StageB(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -228,7 +228,7 @@ class StageB(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
||||
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
|
||||
if pixels is None:
|
||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||
|
||||
@@ -245,8 +245,8 @@ class StageB(nn.Module):
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
level_outputs = self._down_encode(x, r_embed, clip)
|
||||
x = self._up_decode(level_outputs, r_embed, clip)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@@ -182,7 +182,7 @@ class StageC(nn.Module):
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
@@ -201,7 +201,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -213,7 +213,7 @@ class StageC(nn.Module):
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
@@ -235,7 +235,7 @@ class StageC(nn.Module):
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
x = block(x, clip, transformer_options=transformer_options)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
@@ -247,7 +247,7 @@ class StageC(nn.Module):
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
@@ -262,8 +262,8 @@ class StageC(nn.Module):
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
|
||||
@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
|
||||
@@ -151,8 +151,6 @@ class Chroma(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
@@ -193,14 +191,16 @@ class Chroma(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@@ -209,7 +209,8 @@ class Chroma(nn.Module):
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -229,17 +230,19 @@ class Chroma(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -249,8 +252,9 @@ class Chroma(nn.Module):
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
if hasattr(self, "final_layer"):
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
@@ -266,6 +270,9 @@ class Chroma(nn.Module):
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
|
||||
|
||||
if img.ndim != 3 or context.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
|
||||
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
206
comfy/ldm/chroma_radiance/layers.py
Normal file
206
comfy/ldm/chroma_radiance/layers.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Adapted from https://github.com/lodestone-rock/flow
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
An embedder module that combines input features with a 2D positional
|
||||
encoding that mimics the Discrete Cosine Transform (DCT).
|
||||
|
||||
This module takes an input tensor of shape (B, P^2, C), where P is the
|
||||
patch size, and enriches it with positional information before projecting
|
||||
it to a new hidden size.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_size_input: int,
|
||||
max_freqs: int,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
"""
|
||||
Initializes the NerfEmbedder.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input tensor.
|
||||
hidden_size_input (int): The desired dimension of the output embedding.
|
||||
max_freqs (int): The number of frequency components to use for both
|
||||
the x and y dimensions of the positional encoding.
|
||||
The total number of positional features will be max_freqs^2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
|
||||
# A linear layer to project the concatenated input features and
|
||||
# positional encodings to the final output dimension.
|
||||
self.embedder = nn.Sequential(
|
||||
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
||||
|
||||
The LRU cache is a performance optimization that avoids recomputing the
|
||||
same positional grid on every forward pass.
|
||||
|
||||
Args:
|
||||
patch_size (int): The side length of the square input patch.
|
||||
device: The torch device to create the tensors on.
|
||||
dtype: The torch dtype for the tensors.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (1, patch_size^2, max_freqs^2) containing the
|
||||
positional embeddings.
|
||||
"""
|
||||
# Create normalized 1D coordinate grids from 0 to 1.
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
|
||||
# Create a 2D meshgrid of coordinates.
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
|
||||
# Reshape positions to be broadcastable with frequencies.
|
||||
# Shape becomes (patch_size^2, 1, 1).
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
# Create a 1D tensor of frequency values from 0 to max_freqs-1.
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
|
||||
# Reshape frequencies to be broadcastable for creating 2D basis functions.
|
||||
# freqs_x shape: (1, max_freqs, 1)
|
||||
# freqs_y shape: (1, 1, max_freqs)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
|
||||
# A custom weighting coefficient, not part of standard DCT.
|
||||
# This seems to down-weight the contribution of higher-frequency interactions.
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
|
||||
# Calculate the 1D cosine basis functions for x and y coordinates.
|
||||
# This is the core of the DCT formulation.
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
|
||||
# Combine the 1D basis functions to create 2D basis functions by element-wise
|
||||
# multiplication, and apply the custom coefficients. Broadcasting handles the
|
||||
# combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
|
||||
# The result is flattened into a feature vector for each position.
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the embedder.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor of shape (B, P^2, C).
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape (B, P^2, hidden_size_input).
|
||||
"""
|
||||
# Get the batch size, number of pixels, and number of channels.
|
||||
B, P2, C = inputs.shape
|
||||
|
||||
# Infer the patch side length from the number of pixels (P^2).
|
||||
patch_size = int(P2 ** 0.5)
|
||||
|
||||
input_dtype = inputs.dtype
|
||||
inputs = inputs.to(dtype=self.dtype)
|
||||
|
||||
# Fetch the pre-computed or cached positional embeddings.
|
||||
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
|
||||
|
||||
# Repeat the positional embeddings for each item in the batch.
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
|
||||
# Concatenate the original input features with the positional embeddings
|
||||
# along the feature dimension.
|
||||
inputs = torch.cat((inputs, dct), dim=-1)
|
||||
|
||||
# Project the combined tensor to the target hidden size.
|
||||
return self.embedder(inputs).to(dtype=input_dtype)
|
||||
|
||||
|
||||
class NerfGLUBlock(nn.Module):
|
||||
"""
|
||||
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
||||
"""
|
||||
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
# The total number of parameters for the MLP is increased to accommodate
|
||||
# the gate, value, and output projection matrices.
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_x, hidden_size_x = x.shape
|
||||
mlp_params = self.param_generator(s)
|
||||
|
||||
# Split the generated parameters into three parts for the gate, value, and output projection.
|
||||
fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
|
||||
|
||||
# Reshape the parameters into matrices for batch matrix multiplication.
|
||||
fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
|
||||
|
||||
# Normalize the generated weight matrices as in the original implementation.
|
||||
fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
|
||||
fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
|
||||
fc2 = torch.nn.functional.normalize(fc2, dim=-2)
|
||||
|
||||
res_x = x
|
||||
x = self.norm(x)
|
||||
|
||||
# Apply the final output projection.
|
||||
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
|
||||
|
||||
return x + res_x
|
||||
|
||||
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
|
||||
|
||||
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So we temporarily move the channel dimension to the end for the norm operation.
|
||||
return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
|
||||
329
comfy/ldm/chroma_radiance/model.py
Normal file
329
comfy/ldm/chroma_radiance/model.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# Credits:
|
||||
# Original Flux code can be found on: https://github.com/black-forest-labs/flux
|
||||
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
)
|
||||
from .layers import (
|
||||
NerfEmbedder,
|
||||
NerfGLUBlock,
|
||||
NerfFinalLayer,
|
||||
NerfFinalLayerConv,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaRadianceParams(ChromaParams):
|
||||
patch_size: int
|
||||
nerf_hidden_size: int
|
||||
nerf_mlp_ratio: int
|
||||
nerf_depth: int
|
||||
nerf_max_freqs: int
|
||||
# Setting nerf_tile_size to 0 disables tiling.
|
||||
nerf_tile_size: int
|
||||
# Currently one of linear (legacy) or conv.
|
||||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
if operations is None:
|
||||
raise RuntimeError("Attempt to create ChromaRadiance object without setting operations")
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
params = ChromaRadianceParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.in_dim = params.in_dim
|
||||
self.out_dim = params.out_dim
|
||||
self.hidden_dim = params.hidden_dim
|
||||
self.n_layers = params.n_layers
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in_patch = operations.Conv2d(
|
||||
params.in_channels,
|
||||
params.hidden_size,
|
||||
kernel_size=params.patch_size,
|
||||
stride=params.patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
# set as nn identity for now, will overwrite it later.
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
in_dim=self.in_dim,
|
||||
hidden_dim=self.hidden_dim,
|
||||
out_dim=self.out_dim,
|
||||
n_layers=self.n_layers,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
# pixel channel concat with DCT
|
||||
self.nerf_image_embedder = NerfEmbedder(
|
||||
in_channels=params.in_channels,
|
||||
hidden_size_input=params.nerf_hidden_size,
|
||||
max_freqs=params.nerf_max_freqs,
|
||||
dtype=params.nerf_embedder_dtype or dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.nerf_blocks = nn.ModuleList([
|
||||
NerfGLUBlock(
|
||||
hidden_size_s=params.hidden_size,
|
||||
hidden_size_x=params.nerf_hidden_size,
|
||||
mlp_ratio=params.nerf_mlp_ratio,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
) for _ in range(params.nerf_depth)
|
||||
])
|
||||
|
||||
if params.nerf_final_head_type == "linear":
|
||||
self.nerf_final_layer = NerfFinalLayer(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
elif params.nerf_final_head_type == "conv":
|
||||
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
else:
|
||||
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
|
||||
raise ValueError(errstr)
|
||||
|
||||
self.skip_mmdit = []
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
return self.nerf_final_layer
|
||||
if self.params.nerf_final_head_type == "conv":
|
||||
return self.nerf_final_layer_conv
|
||||
# Impossible to get here as we raise an error on unexpected types on initialization.
|
||||
raise NotImplementedError
|
||||
|
||||
def img_in(self, img: Tensor) -> Tensor:
|
||||
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
|
||||
# flatten into a sequence for the transformer.
|
||||
return img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
||||
|
||||
def forward_nerf(
|
||||
self,
|
||||
img_orig: Tensor,
|
||||
img_out: Tensor,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
B, C, H, W = img_orig.shape
|
||||
num_patches = img_out.shape[1]
|
||||
patch_size = params.patch_size
|
||||
|
||||
# Store the raw pixel values of each patch for the NeRF head later.
|
||||
# unfold creates patches: [B, C * P * P, NumPatches]
|
||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||
# the tile size.
|
||||
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||
else:
|
||||
# Reshape for per-patch processing
|
||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
# Pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct = block(img_dct, nerf_hidden)
|
||||
|
||||
# Reassemble the patches into the final image.
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
|
||||
# Reshape to combine with batch dimension for fold
|
||||
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
|
||||
img_dct = nn.functional.fold(
|
||||
img_dct,
|
||||
output_size=(H, W),
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
return self._nerf_final_layer(img_dct)
|
||||
|
||||
def forward_tiled_nerf(
|
||||
self,
|
||||
nerf_hidden: Tensor,
|
||||
nerf_pixels: Tensor,
|
||||
batch: int,
|
||||
channels: int,
|
||||
num_patches: int,
|
||||
patch_size: int,
|
||||
params: ChromaRadianceParams,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Processes the NeRF head in tiles to save memory.
|
||||
nerf_hidden has shape [B, L, D]
|
||||
nerf_pixels has shape [B, L, C * P * P]
|
||||
"""
|
||||
tile_size = params.nerf_tile_size
|
||||
output_tiles = []
|
||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||
for i in range(0, num_patches, tile_size):
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[:, i:end, :]
|
||||
nerf_pixels_tile = nerf_pixels[:, i:end, :]
|
||||
|
||||
# Get the actual number of patches in this tile (can be smaller for the last tile)
|
||||
num_patches_tile = nerf_hidden_tile.shape[1]
|
||||
|
||||
# Reshape the tile for per-patch processing
|
||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
|
||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
|
||||
|
||||
output_tiles.append(img_dct_tile)
|
||||
|
||||
# Concatenate the processed tiles along the patch dimension
|
||||
return torch.cat(output_tiles, dim=0)
|
||||
|
||||
def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
|
||||
params = self.params
|
||||
if not overrides:
|
||||
return params
|
||||
params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
|
||||
nullable_keys = frozenset(("nerf_embedder_dtype",))
|
||||
bad_keys = tuple(k for k in overrides if k not in params_dict)
|
||||
if bad_keys:
|
||||
e = f"Unknown key(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
raise ValueError(e)
|
||||
# At this point it's all valid keys and values so we can merge with the existing params.
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
timestep: Tensor,
|
||||
context: Tensor,
|
||||
guidance: Optional[Tensor],
|
||||
control: Optional[dict]=None,
|
||||
transformer_options: dict={},
|
||||
**kwargs: dict,
|
||||
) -> Tensor:
|
||||
bs, c, h, w = x.shape
|
||||
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
if img.ndim != 4:
|
||||
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
|
||||
if context.ndim != 3:
|
||||
raise ValueError("Input txt tensors must have 3 dimensions.")
|
||||
|
||||
params = self.radiance_get_override_params(transformer_options.get("chroma_radiance_options", {}))
|
||||
|
||||
h_len = (img.shape[-2] // self.patch_size)
|
||||
w_len = (img.shape[-1] // self.patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
img_ids,
|
||||
context,
|
||||
txt_ids,
|
||||
timestep,
|
||||
guidance,
|
||||
control,
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
@@ -176,6 +176,7 @@ class Attention(nn.Module):
|
||||
context=None,
|
||||
mask=None,
|
||||
rope_emb=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -184,7 +185,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
out = rearrange(out, " b n s c -> s b (n c)")
|
||||
return self.to_out(out)
|
||||
@@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
|
||||
context: Optional[torch.Tensor] = None,
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for video attention.
|
||||
@@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
|
||||
context_M_B_D,
|
||||
crossattn_mask,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||
return x_T_H_W_B_D
|
||||
@@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
|
||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||
context=None,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
elif self.block_type in ["cross_attn", "ca"]:
|
||||
x = x + gate_1_1_1_B_D * self.block(
|
||||
@@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
|
||||
context=crossattn_emb,
|
||||
crossattn_mask=crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||
@@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask: Optional[torch.Tensor] = None,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
for block in self.blocks:
|
||||
x = block(
|
||||
@@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
return x
|
||||
|
||||
@@ -520,6 +520,7 @@ class GeneralDIT(nn.Module):
|
||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
for _, block in self.blocks.items():
|
||||
assert (
|
||||
self.blocks["block0"].x_format == block.x_format
|
||||
@@ -534,6 +535,7 @@ class GeneralDIT(nn.Module):
|
||||
crossattn_mask,
|
||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||
|
||||
@@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
@@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -180,8 +180,8 @@ class Attention(nn.Module):
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
@@ -189,6 +189,7 @@ class Attention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -196,7 +197,7 @@ class Attention(nn.Module):
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v)
|
||||
return self.compute_attention(q, k, v, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
@@ -459,6 +460,7 @@ class Block(nn.Module):
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
@@ -512,6 +514,7 @@ class Block(nn.Module):
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@@ -525,6 +528,7 @@ class Block(nn.Module):
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
@@ -534,6 +538,7 @@ class Block(nn.Module):
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
@@ -547,6 +552,7 @@ class Block(nn.Module):
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
@@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module):
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
|
||||
@@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
@@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
@@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
@@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
@@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
|
||||
@@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
@@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@@ -144,14 +144,16 @@ class Flux(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@@ -160,7 +162,8 @@ class Flux(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -181,17 +184,19 @@ class Flux(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
|
||||
@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
|
||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||
crop_y,
|
||||
transformer_options={},
|
||||
**rope_rotation,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
rope_cos = rope_rotation.get("rope_cos")
|
||||
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
xy = optimized_attention(q,
|
||||
k,
|
||||
v, self.num_heads, skip_reshape=True)
|
||||
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||
x = self.proj_x(x)
|
||||
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
transformer_options={},
|
||||
**attn_kwargs,
|
||||
):
|
||||
"""Forward pass of a block.
|
||||
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
|
||||
y,
|
||||
scale_x=scale_msa_x,
|
||||
scale_y=scale_msa_y,
|
||||
transformer_options=transformer_options,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
|
||||
args["txt"],
|
||||
rope_cos=args["rope_cos"],
|
||||
rope_sin=args["rope_sin"],
|
||||
crop_y=args["num_tokens"]
|
||||
crop_y=args["num_tokens"],
|
||||
transformer_options=args["transformer_options"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
y_feat = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
rope_cos=rope_cos,
|
||||
rope_sin=rope_sin,
|
||||
crop_y=num_tokens,
|
||||
transformer_options=transformer_options,
|
||||
) # (B, M, D), (B, L, D)
|
||||
del y_feat # Final layers don't use dense text features.
|
||||
|
||||
|
||||
@@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
|
||||
return t_emb
|
||||
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
|
||||
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
|
||||
|
||||
|
||||
class HiDreamAttnProcessor_flashattn:
|
||||
@@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
image_tokens_masks: Optional[torch.FloatTensor] = None,
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
@@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
|
||||
query = torch.cat([query_1, query_2], dim=-1)
|
||||
key = torch.cat([key_1, key_2], dim=-1)
|
||||
|
||||
hidden_states = attention(query, key, value)
|
||||
hidden_states = attention(query, key, value, transformer_options=transformer_options)
|
||||
|
||||
if not attn.single:
|
||||
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
|
||||
@@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks: torch.FloatTensor = None,
|
||||
norm_text_tokens: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
@@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module):
|
||||
image_tokens_masks = image_tokens_masks,
|
||||
text_tokens = norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
|
||||
@@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
norm_image_tokens,
|
||||
image_tokens_masks,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
|
||||
@@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: Optional[torch.FloatTensor] = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
wtype = image_tokens.dtype
|
||||
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
|
||||
@@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
image_tokens_masks,
|
||||
norm_text_tokens,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
image_tokens = gate_msa_i * attn_output_i + image_tokens
|
||||
@@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens: Optional[torch.FloatTensor] = None,
|
||||
adaln_input: torch.FloatTensor = None,
|
||||
rope: torch.FloatTensor = None,
|
||||
transformer_options={},
|
||||
) -> torch.FloatTensor:
|
||||
return self.block(
|
||||
image_tokens,
|
||||
@@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
|
||||
text_tokens,
|
||||
adaln_input,
|
||||
rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens = cur_encoder_hidden_states,
|
||||
adaln_input = adaln_input,
|
||||
rope = rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
|
||||
block_id += 1
|
||||
@@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
text_tokens=None,
|
||||
adaln_input=adaln_input,
|
||||
rope=rope,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
hidden_states = hidden_states[:, :hidden_states_seq_len]
|
||||
block_id += 1
|
||||
|
||||
@@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
@@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
|
||||
txt=txt,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
@@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": vec,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
"attn_mask": attn_mask,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
img = img[:, txt.shape[1]:, ...]
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
@@ -80,13 +80,13 @@ class TokenRefinerBlock(nn.Module):
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
@@ -117,14 +117,14 @@ class IndividualTokenRefiner(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
def forward(self, x, c, mask, transformer_options={}):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m)
|
||||
x = block(x, c, m, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@@ -152,6 +152,7 @@ class TokenRefiner(nn.Module):
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
transformer_options={},
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
@@ -160,7 +161,7 @@ class TokenRefiner(nn.Module):
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
@@ -278,6 +279,7 @@ class HunyuanVideo(nn.Module):
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
disable_time_r=False,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
@@ -288,7 +290,7 @@ class HunyuanVideo(nn.Module):
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
if self.time_r_in is not None:
|
||||
if (self.time_r_in is not None) and (not disable_time_r):
|
||||
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
|
||||
if len(w) > 0:
|
||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
@@ -327,7 +329,7 @@ class HunyuanVideo(nn.Module):
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
@@ -351,14 +353,14 @@ class HunyuanVideo(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@@ -373,13 +375,13 @@ class HunyuanVideo(nn.Module):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@@ -428,14 +430,14 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
@@ -443,5 +445,5 @@ class HunyuanVideo(nn.Module):
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
267
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
267
comfy/ldm/hunyuan_video/vae_refiner.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
shape = (dim, 1, 1, 1)
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * self.gamma
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds=True):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
|
||||
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = self.conv(x)
|
||||
|
||||
if self.tds:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nf = frms // r1
|
||||
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
b, ci, frms, ht, wd = xn.shape
|
||||
nf = frms // r1
|
||||
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xn.shape
|
||||
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = sc.shape
|
||||
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
return h + sc
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus=True):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
|
||||
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = self.conv(x)
|
||||
|
||||
if self.tus:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
|
||||
b, c, f, ht, wd = xf.shape
|
||||
nc = c // (2 * 2)
|
||||
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = xn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
sc = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = sc.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
return h + sc
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
ch = block_out_channels[0]
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
|
||||
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
|
||||
ch = nxt
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
|
||||
|
||||
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
out = self.regul(out)[0]
|
||||
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
|
||||
super().__init__()
|
||||
block_out_channels = block_out_channels[::-1]
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, 3)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
depth_temporal = (ffactor_temporal >> 1).bit_length()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
|
||||
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
|
||||
ch = nxt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, 3)
|
||||
|
||||
def forward(self, z):
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
|
||||
return self.conv_out(F.silu(self.norm_out(x)))
|
||||
@@ -271,7 +271,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None, pe=None):
|
||||
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
@@ -285,9 +285,9 @@ class CrossAttention(nn.Module):
|
||||
k = apply_rotary_emb(k, pe)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask)
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
@@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module):
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
timestep=timestep,
|
||||
pe=pe
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
|
||||
@@ -104,6 +104,7 @@ class JointAttention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
transformer_options={},
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@@ -140,7 +141,7 @@ class JointAttention(nn.Module):
|
||||
if n_rep >= 1:
|
||||
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
|
||||
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
return self.out(output)
|
||||
|
||||
@@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
Perform a forward pass through the TransformerBlock.
|
||||
@@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module):
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
@@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module):
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
@@ -494,7 +498,7 @@ class NextDiT(nn.Module):
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
@@ -554,7 +558,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
@@ -573,7 +577,7 @@ class NextDiT(nn.Module):
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
@@ -616,12 +620,13 @@ class NextDiT(nn.Module):
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input)
|
||||
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
|
||||
@@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
z = posterior.mode()
|
||||
return z, None
|
||||
|
||||
class EmptyRegularizer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
return z, None
|
||||
|
||||
class AbstractAutoencoder(torch.nn.Module):
|
||||
"""
|
||||
|
||||
@@ -5,8 +5,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional
|
||||
from typing import Optional, Any, Callable, Union
|
||||
import logging
|
||||
import functools
|
||||
|
||||
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
@@ -17,23 +18,45 @@ if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ModuleNotFoundError as e:
|
||||
SAGE_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGE_ATTENTION_IS_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
if model_management.sage_attention_enabled():
|
||||
if e.name == "sageattention":
|
||||
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
||||
else:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
if model_management.flash_attention_enabled():
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
if model_management.flash_attention_enabled():
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||
def register_attention_function(name: str, func: Callable):
|
||||
# avoid replacing existing functions
|
||||
if name not in REGISTERED_ATTENTION_FUNCTIONS:
|
||||
REGISTERED_ATTENTION_FUNCTIONS[name] = func
|
||||
else:
|
||||
logging.warning(f"Attention function {name} already registered, skipping registration.")
|
||||
|
||||
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
|
||||
if name == "optimized":
|
||||
return optimized_attention
|
||||
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
|
||||
if default is ...:
|
||||
raise KeyError(f"Attention function {name} not found.")
|
||||
else:
|
||||
return default
|
||||
return REGISTERED_ATTENTION_FUNCTIONS[name]
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -91,7 +114,27 @@ class FeedForward(nn.Module):
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
|
||||
def wrap_attn(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
remove_attn_wrapper_key = False
|
||||
try:
|
||||
if "_inside_attn_wrapper" not in kwargs:
|
||||
transformer_options = kwargs.get("transformer_options", None)
|
||||
remove_attn_wrapper_key = True
|
||||
kwargs["_inside_attn_wrapper"] = True
|
||||
if transformer_options is not None:
|
||||
if "optimized_attention_override" in transformer_options:
|
||||
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
if remove_attn_wrapper_key:
|
||||
del kwargs["_inside_attn_wrapper"]
|
||||
return wrapper
|
||||
|
||||
@wrap_attn
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
attn_precision = get_attn_precision(attn_precision, query.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
attn_precision = get_attn_precision(attn_precision, q.dtype)
|
||||
|
||||
if skip_reshape:
|
||||
@@ -359,7 +403,8 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
@@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
disabled_xformers = True
|
||||
|
||||
if disabled_xformers:
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
|
||||
|
||||
if skip_reshape:
|
||||
# b h k d -> b k h d
|
||||
@@ -427,8 +472,8 @@ else:
|
||||
#TODO: other GPUs ?
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
@@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
lambda t: t.transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
@@ -534,8 +579,8 @@ except AttributeError as error:
|
||||
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||
@wrap_attn
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
@@ -555,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
try:
|
||||
assert mask is None
|
||||
if mask is not None:
|
||||
raise RuntimeError("Mask must not be set for Flash attention")
|
||||
out = flash_attn_wrapper(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
@@ -597,6 +643,19 @@ else:
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
register_attention_function("xformers", attention_xformers)
|
||||
register_attention_function("pytorch", attention_pytorch)
|
||||
register_attention_function("sub_quad", attention_sub_quad)
|
||||
register_attention_function("split", attention_split)
|
||||
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
if small_input:
|
||||
if model_management.pytorch_attention_enabled():
|
||||
@@ -629,7 +688,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
@@ -640,9 +699,9 @@ class CrossAttention(nn.Module):
|
||||
v = self.to_v(context)
|
||||
|
||||
if mask is None:
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@@ -746,7 +805,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
||||
n = self.attn1.to_out(n)
|
||||
else:
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
@@ -786,7 +845,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||
n = self.attn2.to_out(n)
|
||||
else:
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
|
||||
|
||||
if "attn2_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_output_patch"]
|
||||
@@ -1017,7 +1076,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
|
||||
B, S, C = x_mix.shape
|
||||
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
|
||||
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
|
||||
x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
|
||||
x_mix = rearrange(
|
||||
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||
)
|
||||
|
||||
@@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
||||
return _block_mixing(*args, **kwargs)
|
||||
|
||||
|
||||
def _block_mixing(context, x, context_block, x_block, c):
|
||||
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
@@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
attn = optimized_attention(
|
||||
qkv[0], qkv[1], qkv[2],
|
||||
heads=x_block.attn.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
@@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
|
||||
attn2 = optimized_attention(
|
||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||
heads=x_block.attn2.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
@@ -958,10 +960,10 @@ class MMDiT(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
context = out["txt"]
|
||||
x = out["img"]
|
||||
else:
|
||||
@@ -970,6 +972,7 @@ class MMDiT(nn.Module):
|
||||
x,
|
||||
c=c_mod,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
if control is not None:
|
||||
control_o = control.get("output")
|
||||
|
||||
@@ -145,7 +145,7 @@ class Downsample(nn.Module):
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
|
||||
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.norm1 = norm_op(in_channels)
|
||||
self.conv1 = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
@@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.norm2 = norm_op(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||
self.conv2 = conv_op(out_channels,
|
||||
out_channels,
|
||||
@@ -305,11 +305,11 @@ def vae_attention():
|
||||
return normal_attention
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.norm = norm_op(in_channels)
|
||||
self.q = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
|
||||
@@ -120,7 +120,7 @@ class Attention(nn.Module):
|
||||
nn.Dropout(0.0)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
@@ -146,7 +146,7 @@ class Attention(nn.Module):
|
||||
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
|
||||
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
|
||||
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
|
||||
if self.modulation:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
|
||||
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
|
||||
hidden_states = hidden_states + self.norm2(attn_output)
|
||||
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
||||
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
||||
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
ref_img_sizes, img_sizes,
|
||||
)
|
||||
|
||||
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
|
||||
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
|
||||
batch_size = len(hidden_states)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
shift += ref_img_len
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
||||
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
|
||||
|
||||
if ref_image_hidden_states is not None:
|
||||
for layer in self.ref_image_refiner:
|
||||
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
|
||||
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
|
||||
|
||||
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
_, _, H_padded, W_padded = hidden_states.shape
|
||||
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
|
||||
|
||||
img_len = hidden_states.shape[1]
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
@@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
|
||||
noise_rotary_emb, ref_img_rotary_emb,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
|
||||
attention_mask = None
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
||||
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
|
||||
@@ -132,6 +132,7 @@ class Attention(nn.Module):
|
||||
encoder_hidden_states_mask: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
@@ -159,7 +160,7 @@ class Attention(nn.Module):
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
@@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
@@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states=txt_modulated,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
@@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
@@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module):
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
eps=1e-6,
|
||||
kv_dim=None,
|
||||
operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -43,16 +45,18 @@ class WanSelfAttention(nn.Module):
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
if kv_dim is None:
|
||||
kv_dim = dim
|
||||
|
||||
# layers
|
||||
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs):
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
@@ -60,21 +64,26 @@ class WanSelfAttention(nn.Module):
|
||||
"""
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
def qkv_fn(x):
|
||||
def qkv_fn_q(x):
|
||||
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
||||
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||
v = self.v(x).view(b, s, n * d)
|
||||
return q, k, v
|
||||
return apply_rope1(q, freqs)
|
||||
|
||||
q, k, v = qkv_fn(x)
|
||||
q, k = apply_rope(q, k, freqs)
|
||||
def qkv_fn_k(x):
|
||||
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||
return apply_rope1(k, freqs)
|
||||
|
||||
#These two are VRAM hogs, so we want to do all of q computation and
|
||||
#have pytorch garbage collect the intermediates on the sub function
|
||||
#return before we touch k
|
||||
q = qkv_fn_q(x)
|
||||
k = qkv_fn_k(x)
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v,
|
||||
self.v(x).view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
@@ -83,7 +92,7 @@ class WanSelfAttention(nn.Module):
|
||||
|
||||
class WanT2VCrossAttention(WanSelfAttention):
|
||||
|
||||
def forward(self, x, context, **kwargs):
|
||||
def forward(self, x, context, transformer_options={}, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
@@ -95,7 +104,7 @@ class WanT2VCrossAttention(WanSelfAttention):
|
||||
v = self.v(context)
|
||||
|
||||
# compute attention
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
@@ -116,7 +125,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, context, context_img_len):
|
||||
def forward(self, x, context, context_img_len, transformer_options={}):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
@@ -131,9 +140,9 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
v = self.v(context)
|
||||
k_img = self.norm_k_img(self.k_img(context_img))
|
||||
v_img = self.v_img(context_img)
|
||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
|
||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
|
||||
# compute attention
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
|
||||
|
||||
# output
|
||||
x = x + img_x
|
||||
@@ -206,6 +215,7 @@ class WanAttentionBlock(nn.Module):
|
||||
freqs,
|
||||
context,
|
||||
context_img_len=257,
|
||||
transformer_options={},
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@@ -224,12 +234,12 @@ class WanAttentionBlock(nn.Module):
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs)
|
||||
freqs, transformer_options=transformer_options)
|
||||
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
@@ -396,6 +406,7 @@ class WanModel(torch.nn.Module):
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
in_dim_ref_conv=None,
|
||||
wan_attn_block_class=WanAttentionBlock,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -473,8 +484,8 @@ class WanModel(torch.nn.Module):
|
||||
# blocks
|
||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||
self.blocks = nn.ModuleList([
|
||||
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
@@ -559,12 +570,12 @@ class WanModel(torch.nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
@@ -742,17 +753,17 @@ class VaceWanModel(WanModel):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
ii = self.vace_layers_mapping.get(i, None)
|
||||
if ii is not None:
|
||||
for iii in range(len(c)):
|
||||
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
x += c_skip * vace_strength[iii]
|
||||
del c_skip
|
||||
# head
|
||||
@@ -841,12 +852,12 @@ class CameraWanModel(WanModel):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
@@ -1319,3 +1330,247 @@ class WanModel_S2V(WanModel):
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
|
||||
class WanT2VCrossAttentionGather(WanSelfAttention):
|
||||
|
||||
def forward(self, x, context, transformer_options={}, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C] - video tokens
|
||||
context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
|
||||
"""
|
||||
b, n, d = x.size(0), self.num_heads, self.head_dim
|
||||
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(context))
|
||||
v = self.v(context)
|
||||
|
||||
# Handle audio temporal structure (16 tokens per frame)
|
||||
k = k.reshape(-1, 16, n, d).transpose(1, 2)
|
||||
v = v.reshape(-1, 16, n, d).transpose(1, 2)
|
||||
|
||||
# Handle video spatial structure
|
||||
q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
|
||||
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class AudioCrossAttentionWrapper(nn.Module):
|
||||
def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
|
||||
self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
|
||||
self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x, audio, transformer_options={}):
|
||||
x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
class WanAttentionBlockAudio(WanAttentionBlock):
|
||||
|
||||
def __init__(self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
|
||||
self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
freqs,
|
||||
context,
|
||||
context_img_len=257,
|
||||
audio=None,
|
||||
transformer_options={},
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
e(Tensor): Shape [B, 6, C]
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
|
||||
# cross-attention & ffn
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if audio is not None:
|
||||
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
class DummyAdapterLayer(nn.Module):
|
||||
def __init__(self, layer):
|
||||
super().__init__()
|
||||
self.layer = layer
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.layer(*args, **kwargs)
|
||||
|
||||
|
||||
class AudioProjModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len=5,
|
||||
blocks=13, # add a new parameter blocks
|
||||
channels=768, # add a new parameter channels
|
||||
intermediate_dim=512,
|
||||
output_dim=1536,
|
||||
context_tokens=16,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels # update input_dim to be the product of blocks and channels.
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.output_dim = output_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
|
||||
self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
|
||||
self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
|
||||
|
||||
self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, audio_embeds):
|
||||
video_length = audio_embeds.shape[1]
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
|
||||
audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
|
||||
|
||||
context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
|
||||
|
||||
context_tokens = self.audio_proj_glob_norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class HumoWanModel(WanModel):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='humo',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
image_model=None,
|
||||
audio_token_num=16,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
freqs=None,
|
||||
audio_embed=None,
|
||||
reference_latent=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
):
|
||||
bs, _, time, height, width = x.shape
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
if reference_latent is not None:
|
||||
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||
ref = ref.flatten(2).transpose(1, 2)
|
||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
|
||||
x = torch.cat([x, ref], dim=1)
|
||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||
del ref, freqs_ref
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
context_img_len = None
|
||||
|
||||
if audio_embed is not None:
|
||||
audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
|
||||
else:
|
||||
audio = None
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@@ -297,6 +297,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.Omnigen2):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.QwenImage):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"): #QwenImage lora format
|
||||
|
||||
@@ -42,6 +42,7 @@ import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
@@ -1212,6 +1213,46 @@ class WAN21_Camera(WAN21):
|
||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||
return out
|
||||
|
||||
class WAN21_HuMo(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
audio_embed = kwargs.get("audio_embed", None)
|
||||
if audio_embed is not None:
|
||||
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
|
||||
|
||||
if "c_concat" not in out: # 1.7B model
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
|
||||
else:
|
||||
noise_shape = list(noise.shape)
|
||||
noise_shape[1] += 4
|
||||
concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
|
||||
zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
|
||||
zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
|
||||
zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
|
||||
concat_latent[:, 4:] = zero_vae_values
|
||||
concat_latent[:, 4:, :1] = zero_vae_values_first
|
||||
concat_latent[:, 4:, 1:2] = zero_vae_values_second
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
ref_latent = self.process_latent_in(reference_latents[-1])
|
||||
ref_latent_shape = list(ref_latent.shape)
|
||||
ref_latent_shape[1] += 4 + ref_latent_shape[1]
|
||||
ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype)
|
||||
ref_latent_full[:, 20:] = ref_latent
|
||||
ref_latent_full[:, 16:20] = 1.0
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full)
|
||||
|
||||
return out
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
@@ -1320,8 +1361,8 @@ class HiDream(BaseModel):
|
||||
return out
|
||||
|
||||
class Chroma(Flux):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -1331,6 +1372,10 @@ class Chroma(Flux):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma_radiance.model.ChromaRadiance)
|
||||
|
||||
class ACEStep(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
|
||||
@@ -1432,3 +1477,31 @@ class HunyuanImage21(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
return out
|
||||
|
||||
class HunyuanImage21Refiner(HunyuanImage21):
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape_image = list(noise.shape)
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
image = self.process_latent_in(image)
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
if noise_augmentation > 0:
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(kwargs.get("seed", 0) - 10)
|
||||
noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
|
||||
image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image
|
||||
else:
|
||||
image = 0.75 * image
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(True)
|
||||
return out
|
||||
|
||||
@@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["in_channels"] = 16
|
||||
@@ -204,6 +204,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["out_dim"] = 3072
|
||||
dit_config["hidden_dim"] = 5120
|
||||
dit_config["n_layers"] = 5
|
||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
dit_config["patch_size"] = 16
|
||||
dit_config["nerf_hidden_size"] = 64
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 32
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
@@ -390,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "camera_2.2"
|
||||
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "s2v"
|
||||
elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "humo"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
|
||||
16
comfy/pixel_space_convert.py
Normal file
16
comfy/pixel_space_convert.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import torch
|
||||
|
||||
|
||||
# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
|
||||
# to LATENT B, C, H, W and values on the scale of -1..1.
|
||||
class PixelspaceConversionVAE(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pixel_space_vae = torch.nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
|
||||
return pixels
|
||||
|
||||
def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
|
||||
return samples
|
||||
|
||||
39
comfy/sd.py
39
comfy/sd.py
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.pixel_space_convert
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
@@ -285,6 +286,7 @@ class VAE:
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
@@ -409,6 +411,23 @@ class VAE:
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
self.downscale_index_formula = (8, 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.latent_channels = 64
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.not_video = True
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
@@ -498,6 +517,15 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
elif "pixel_space_vae" in sd:
|
||||
self.first_stage_model = comfy.pixel_space_convert.PixelspaceConversionVAE()
|
||||
self.memory_used_encode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.downscale_ratio = 1
|
||||
self.upscale_ratio = 1
|
||||
self.latent_channels = 3
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@@ -670,7 +698,10 @@ class VAE:
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@@ -704,7 +735,10 @@ class VAE:
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if dims == 3:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@@ -761,6 +795,7 @@ class VAE:
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
|
||||
@@ -1073,6 +1073,16 @@ class WAN21_Vace(WAN21_T2V):
|
||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_HuMo(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "humo",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_S2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1205,6 +1215,19 @@ class Chroma(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
unet_config = {
|
||||
"image_model": "chroma_radiance",
|
||||
}
|
||||
|
||||
latent_format = comfy.latent_formats.ChromaRadiance
|
||||
|
||||
# Pixel-space model, no spatial compression for model input.
|
||||
memory_usage_factor = 0.038
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ChromaRadiance(self, device=device)
|
||||
|
||||
class ACEStep(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "ace",
|
||||
@@ -1321,6 +1344,23 @@ class HunyuanImage21(HunyuanVideo):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
class HunyuanImage21Refiner(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"patch_size": [1, 1, 1],
|
||||
"vec_in_dim": None,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 4.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.HunyuanImage21Refiner
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -22,17 +22,14 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
|
||||
|
||||
# ByT5 processing for HunyuanImage
|
||||
text_prompt_texts = []
|
||||
pattern_quote_single = r'\'(.*?)\''
|
||||
pattern_quote_double = r'\"(.*?)\"'
|
||||
pattern_quote_chinese_single = r'‘(.*?)’'
|
||||
pattern_quote_chinese_double = r'“(.*?)”'
|
||||
|
||||
matches_quote_single = re.findall(pattern_quote_single, text)
|
||||
matches_quote_double = re.findall(pattern_quote_double, text)
|
||||
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
|
||||
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
|
||||
|
||||
text_prompt_texts.extend(matches_quote_single)
|
||||
text_prompt_texts.extend(matches_quote_double)
|
||||
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||
|
||||
Reference in New Issue
Block a user