mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-19 22:19:02 +00:00
Fix cogvideox dtypes and ops.
This commit is contained in:
@@ -378,7 +378,7 @@ class CogVideoXTransformer3DModel(nn.Module):
|
||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
||||
device=device, dtype=torch.float32, operations=operations,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
# 2. Time embedding
|
||||
|
||||
@@ -80,7 +80,7 @@ class SpatialNorm3D(nn.Module):
|
||||
"""Spatially conditioned normalization."""
|
||||
def __init__(self, f_channels, zq_channels, groups=32):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
@@ -115,8 +115,8 @@ class ResnetBlock3D(nn.Module):
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
|
||||
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
|
||||
@@ -124,7 +124,7 @@ class ResnetBlock3D(nn.Module):
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
@@ -167,7 +167,7 @@ class Downsample3D(nn.Module):
|
||||
"""3D downsampling with optional temporal compression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
@@ -197,7 +197,7 @@ class Upsample3D(nn.Module):
|
||||
"""3D upsampling with optional temporal decompression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
@@ -332,7 +332,7 @@ class Encoder3D(nn.Module):
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
||||
self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user