mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Small refactor (#1246)
* Remove kIsFp8
* Extract alias
* Fix K, V and corresponding acc type

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -27,13 +27,12 @@ struct FmhaFwdKernel
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8;
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
|
||||
@@ -49,13 +49,6 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kIsFp8 =
|
||||
(std::is_same_v<QDataType, fp8_t> || std::is_same_v<QDataType, bf8_t>)&&(
|
||||
std::is_same_v<KDataType, fp8_t> ||
|
||||
std::is_same_v<KDataType, bf8_t>)&&(std::is_same_v<VDataType, fp8_t> ||
|
||||
std::is_same_v<VDataType, bf8_t>)&&std::
|
||||
is_same_v<SaccDataType, float> &&
|
||||
std::is_same_v<OaccDataType, float>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -31,7 +31,6 @@ struct BlockFmhaPipelineQRKSVS
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
static constexpr bool kIsFp8 = Problem::kIsFp8;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
static constexpr bool kIsFp8 = Problem::kIsFp8;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
static constexpr bool kIsFp8 = Problem::kIsFp8;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ struct BlockFmhaPipelineQSKSVS
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = false;
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
static constexpr bool kIsFp8 = Problem::kIsFp8;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
|
||||
@@ -97,16 +97,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
{
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(Problem::kIsFp8)
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
swizzle_factor>>{};
|
||||
}
|
||||
// TODO: hard-coded here; otherwise, it may produce an incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
@@ -221,16 +220,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
{
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(Problem::kIsFp8)
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
swizzle_factor>>{};
|
||||
}
|
||||
// TODO: hard-coded here; otherwise, it may produce an incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
@@ -920,12 +918,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
Problem::BlockFmhaShape::kK1>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(Problem::kIsFp8)
|
||||
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::VDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::OaccDataType, float>)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename Problem::PDataType,
|
||||
typename Problem::VDataType>,
|
||||
2>>{};
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
|
||||
// return
|
||||
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
|
||||
@@ -102,4 +102,11 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
|
||||
|
||||
template <index_t swizzle_factor = 2>
|
||||
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>,
|
||||
2,
|
||||
swizzle_factor>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user