Add support for WarpGemm-16x16x32 in QK-BlockGemm (which enables using ds_write/read_b128 for K)

This commit is contained in:
Qianfeng Zhang
2025-04-25 06:06:50 +00:00
parent a41371f734
commit 05910ebe0b

View File

@@ -364,6 +364,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QKVDataType, half_t> &&
@@ -372,7 +373,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
{
if constexpr(WarpGemmK == 32)
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
else
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
}
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
@@ -382,7 +388,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
{
if constexpr(WarpGemmK == 32)
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
else
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
}
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}