diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 6ee53005a0..df1a5cb9c8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -364,6 +364,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); if constexpr(std::is_same_v && @@ -372,7 +373,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + { + if constexpr(WarpGemmK == 32) + return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}; + else + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + } else // WarpGemmM == 4 return WarpGemmMfmaF16F16F32M4N64K16{}; } @@ -382,7 +388,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + { + if constexpr(WarpGemmK == 32) + return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}; + else + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + } else // WarpGemmM == 4 return WarpGemmMfmaBf16Bf16F32M4N64K16{}; }