Simplify the warp_gemm definitions in GetQKBlockGemm and GetKVBlockGemm

This commit is contained in:
Qianfeng Zhang
2025-09-25 15:38:55 +00:00
parent bd32cc0de0
commit 27b96b15c4

View File

@@ -3,7 +3,6 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
@@ -583,53 +582,51 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
auto warp_gemm = [&]() {
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
std::is_same_v<typename Problem::QKVDataType, bf16_t>)&&std::
is_same_v<typename Problem::GemmAccDataType, float>)
{
constexpr index_t WarpGemmM =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
if constexpr(std::is_same_v<typename Problem::QKVDataType, half_t> &&
std::is_same_v<typename Problem::GemmAccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
{
if constexpr(WarpGemmK == 32)
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>{};
else
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
}
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QKVDataType, bf16_t> &&
std::is_same_v<typename Problem::GemmAccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
{
if constexpr(WarpGemmK == 32)
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>{};
else
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
}
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QKVDataType, fp8_t> &&
std::is_same_v<typename Problem::GemmAccDataType, float>)
{
static_assert(WarpGemmM == 32);
#ifdef __gfx950__
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
// TODO: hard coded here. Otherwise, it may produce an incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
swizzle_factor>{};
} // TODO - bf8_t
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
#else
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
"Not supported WarpGemm sizes!");
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
#endif
}
else
{
static_assert(false, "Not supported data types!");
}
}();
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
@@ -667,25 +664,46 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
typename Problem::HstuAttentionTileSetting::Gemm1WarpTile>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::QKVDataType, fp8_t> &&
std::is_same_v<typename Problem::GemmAccDataType, float>)
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
std::is_same_v<typename Problem::QKVDataType, bf16_t>)&&std::
is_same_v<typename Problem::GemmAccDataType, float>)
{
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
// Problem::PDataType, typename Problem::VDataType>>>{};
constexpr index_t WarpGemmM =
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{});
constexpr index_t WarpGemmK =
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{});
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
"Not supported WarpGemm sizes!");
if constexpr(WarpGemmK == 32)
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Double>{};
else
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
}
else
{
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
true>{};
static_assert(false, "Not supported data types!");
}
}();