Support fp16 for Cosmos-Predict2 and Anima (#12249)

This commit is contained in:
tdrussell
2026-02-06 19:12:15 -06:00
committed by GitHub
parent 204e65b8dc
commit 6a26328842
2 changed files with 19 additions and 9 deletions

View File

@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
device=None, dtype=None, operations=None
):
super().__init__()
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
@@ -463,6 +463,8 @@ class Block(nn.Module):
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
residual_dtype = x_B_T_H_W_D.dtype
compute_dtype = emb_B_T_D.dtype
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -512,7 +514,7 @@ class Block(nn.Module):
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -522,7 +524,7 @@ class Block(nn.Module):
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@@ -536,7 +538,7 @@ class Block(nn.Module):
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -555,7 +557,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@@ -563,8 +565,8 @@ class Block(nn.Module):
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
return x_B_T_H_W_D
@@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
# in fp32, but run attention and MLP modules in fp16.
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
# quality degradation and visual artifacts.
if x_B_T_H_W_D.dtype == torch.float16:
x_B_T_H_W_D = x_B_T_H_W_D.float()
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,

View File

@@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
@@ -1023,7 +1023,7 @@ class Anima(supported_models_base.BASE):
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)