From 05910ebe0bc45c8fc7d724d55f5ba07d4b3201c9 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 25 Apr 2025 06:06:50 +0000 Subject: [PATCH] Add support for WarpGem-16x16x32 in QK-BlockGemm (which enables using ds_write/read_b128 for K --- ...hstu_attention_fwd_pipeline_default_policy.hpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 6ee53005a0..df1a5cb9c8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -364,6 +364,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr auto warp_gemm = []() { constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); if constexpr(std::is_same_v && @@ -372,7 +373,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy if constexpr(WarpGemmM == 32) return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + { + if constexpr(WarpGemmK == 32) + return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}; + else + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + } else // WarpGemmM == 4 return WarpGemmMfmaF16F16F32M4N64K16{}; } @@ -382,7 +388,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy if constexpr(WarpGemmM == 32) return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; else if constexpr(WarpGemmM == 16) - return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + { + if constexpr(WarpGemmK == 32) + return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}; + else + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + } else // WarpGemmM == 4 return WarpGemmMfmaBf16Bf16F32M4N64K16{}; }