mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Simplify the warp_gemm definitions in GetQKBlockGemm and GetKVBlockGemm
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
@@ -583,53 +582,51 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
|
||||
std::is_same_v<typename Problem::QKVDataType, bf16_t>)&&std::
|
||||
is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QKVDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
{
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>{};
|
||||
else
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
}
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QKVDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
{
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>{};
|
||||
else
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
}
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QKVDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
#ifdef __gfx950__
|
||||
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>{};
|
||||
#else
|
||||
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
|
||||
"Not supported WarpGemm sizes!");
|
||||
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>{};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Not supported data types!");
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
|
||||
@@ -667,25 +664,46 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
typename Problem::HstuAttentionTileSetting::Gemm1WarpTile>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(std::is_same_v<typename Problem::QKVDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
|
||||
std::is_same_v<typename Problem::QKVDataType, bf16_t>)&&std::
|
||||
is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
|
||||
// return
|
||||
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
// Problem::PDataType, typename Problem::VDataType>>>{};
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK =
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{});
|
||||
|
||||
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)),
|
||||
"Not supported WarpGemm sizes!");
|
||||
|
||||
if constexpr(WarpGemmK == 32)
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Double>{};
|
||||
else
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
static_assert(false, "Not supported data types!");
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user