From bb7bc263a31e06a4013a5836ffad0466755196d7 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 22 Apr 2024 20:28:49 +0800 Subject: [PATCH] Small refactor (#1246) * Remove kIsFp8 * Extract alias * Fix K, V and corresponding acc type --------- Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: 43879b89e4269199e445fc3038b6fe1d097cbb06] --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 13 +++--- .../pipeline/block_fmha_pipeline_problem.hpp | 7 --- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 1 - .../block_fmha_pipeline_qr_ks_vs_async.hpp | 1 - .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 1 - .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 1 - ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 43 +++++++++---------- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 7 +++ 8 files changed, 33 insertions(+), 41 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 822348eb3d..0732fd2ce2 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -27,13 +27,12 @@ struct FmhaFwdKernel static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; - using QDataType = ck_tile::remove_cvref_t; - using KDataType = ck_tile::remove_cvref_t; - using VDataType = ck_tile::remove_cvref_t; - using BiasDataType = ck_tile::remove_cvref_t; - using LSEDataType = ck_tile::remove_cvref_t; - using ODataType = ck_tile::remove_cvref_t; - static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8; + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; using VLayout = ck_tile::remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 0dcd366173..9d27b2df68 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -49,13 +49,6 @@ struct BlockFmhaPipelineProblem static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr bool kIsFp8 = - (std::is_same_v || std::is_same_v)&&( - std::is_same_v || - std::is_same_v)&&(std::is_same_v || - std::is_same_v)&&std:: - is_same_v && - std::is_same_v; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 2521d84409..9e239bb916 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -31,7 +31,6 @@ struct BlockFmhaPipelineQRKSVS using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); - static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index b54b18e111..0573b50d04 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -32,7 +32,6 @@ struct BlockFmhaPipelineQRKSVSAsync using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); - static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 34ff50347b..0e59ee6fe0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -31,7 +31,6 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); - static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 985b678e70..677c05769c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -30,7 +30,6 @@ struct BlockFmhaPipelineQSKSVS using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = false; static_assert(kQLoadOnce == Policy::QLoadOnce); - static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 712c0ca2c9..4fda6f008f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -97,16 +97,15 @@ struct BlockFmhaPipelineQXCustomPolicy { return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; } - else if constexpr(Problem::kIsFp8) + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) { - constexpr index_t swizzle_factor = 4; // TODO: hard coded here - return WarpGemmImpl< - WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< - WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base, - 2, - swizzle_factor>>{}; - } + // TODO: hard coded here. Otherwise, it may incorrect result + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } // TODO - bf8_t }(); using BlockGemmPolicy = @@ -221,16 +220,15 @@ struct BlockFmhaPipelineQXCustomPolicy { return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; } - else if constexpr(Problem::kIsFp8) + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) { - constexpr index_t swizzle_factor = 4; // TODO: hard coded here - return WarpGemmImpl< - WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< - WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base, - 2, - swizzle_factor>>{}; - } + // TODO: hard coded here. Otherwise, it may incorrect result + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } // TODO - bf8_t }(); using BlockGemmPolicy = @@ -920,12 +918,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy>; auto warp_gemm = [&]() { - if constexpr(Problem::kIsFp8) + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) { - return WarpGemmImpl, - 2>>{}; + return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{}; // return // WarpGemmImpl>; +template +using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution = + WarpGemmImpl, + 2, + swizzle_factor>>; + } // namespace ck_tile