mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Add support for WarpGem-16x16x32 in QK-BlockGemm (which enables using ds_write/read_b128 for K
This commit is contained in:
@@ -364,6 +364,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QKVDataType, half_t> &&
|
||||
@@ -372,7 +373,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
{
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
|
||||
else
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
}
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
@@ -382,7 +388,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
{
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
|
||||
else
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
}
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user