Fix cogvideox dtypes and ops.

This commit is contained in:
Talmaj Marinc
2026-04-14 15:58:34 +02:00
parent c8a843e240
commit dff15d7e5f
2 changed files with 8 additions and 8 deletions

View File

@@ -378,7 +378,7 @@ class CogVideoXTransformer3DModel(nn.Module):
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
device=device, dtype=torch.float32, operations=operations,
device=device, dtype=dtype, operations=operations,
)
# 2. Time embedding

View File

@@ -80,7 +80,7 @@ class SpatialNorm3D(nn.Module):
"""Spatially conditioned normalization."""
def __init__(self, f_channels, zq_channels, groups=32):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
@@ -115,8 +115,8 @@ class ResnetBlock3D(nn.Module):
self.nonlinearity = nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
@@ -124,7 +124,7 @@ class ResnetBlock3D(nn.Module):
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
self.temb_proj = ops.Linear(temb_channels, out_channels)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
@@ -167,7 +167,7 @@ class Downsample3D(nn.Module):
"""3D downsampling with optional temporal compression."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x):
@@ -197,7 +197,7 @@ class Upsample3D(nn.Module):
"""3D upsampling with optional temporal decompression."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x):
@@ -332,7 +332,7 @@ class Encoder3D(nn.Module):
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
)
self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)