diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index c270e6333..6491e486b 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -335,7 +335,7 @@ class FinalLayer(nn.Module): device=None, dtype=None, operations=None ): super().__init__() - self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = operations.Linear( hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype ) @@ -463,6 +463,8 @@ class Block(nn.Module): extra_per_block_pos_emb: Optional[torch.Tensor] = None, transformer_options: Optional[dict] = {}, ) -> torch.Tensor: + residual_dtype = x_B_T_H_W_D.dtype + compute_dtype = emb_B_T_D.dtype if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -512,7 +514,7 @@ class Block(nn.Module): result_B_T_H_W_D = rearrange( self.self_attn( # normalized_x_B_T_HW_D, - rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -522,7 +524,7 @@ class Block(nn.Module): h=H, w=W, ) - x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) def _x_fn( _x_B_T_H_W_D: torch.Tensor, @@ -536,7 +538,7 @@ class Block(nn.Module): ) _result_B_T_H_W_D = rearrange( self.cross_attn( - rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -555,7 +557,7 @@ class Block(nn.Module): shift_cross_attn_B_T_1_1_D, transformer_options=transformer_options, ) - x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D normalized_x_B_T_H_W_D = _fn( x_B_T_H_W_D, @@ -563,8 +565,8 @@ class Block(nn.Module): scale_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D, ) - result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) - x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) return x_B_T_H_W_D @@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module): "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "transformer_options": kwargs.get("transformer_options", {}), } + + # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream + # in fp32, but run attention and MLP modules in fp16. + # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable + # quality degradation and visual artifacts. + if x_B_T_H_W_D.dtype == torch.float16: + x_B_T_H_W_D = x_B_T_H_W_D.float() + for block in self.blocks: x_B_T_H_W_D = block( x_B_T_H_W_D, diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 77264ed28..56a21b0ef 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config) @@ -1023,7 +1023,7 @@ class Anima(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config)