diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 18980ee0f4..d78f05ea2b 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -220,9 +220,9 @@ int run_flatmm_example(int argc, char* argv[]) std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "C") { - + if(data_type == "fp16") - { + { std::cout << "Running with fp16 data type" << std::endl; run_flatmm_example_with_layouts>( argc, argv, Row{}, Col{}, Row{}); @@ -264,7 +264,7 @@ int main(int argc, char* argv[]) { int warp_tile = arg_parser.get_int("warp_tile"); if(warp_tile == 0) - { + { std::cout << "Running with warp tile size 16x16" << std::endl; return !run_flatmm_example(argc, argv); } diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp old mode 100755 new mode 100644 index db647915b8..5f7b517513 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -590,27 +590,27 @@ struct FlatmmKernel operator()( c_block_window, c_block_tile, d_block_window, smem_ptr); } - } + } CK_TILE_DEVICE static void RunFlatmm2(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* smem_ptr_ping, - void* smem_ptr_pong, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) + const BDataType* b_flat_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr_ping, + void* smem_ptr_pong, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows + // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - + const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - + // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_flat_block_window = gemm_tile_windows.at(I1); @@ -651,7 +651,8 @@ struct FlatmmKernel if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value) && FlatmmPipeline::DoubleSmemBuffer == false) + is_any_of::value) && + FlatmmPipeline::DoubleSmemBuffer == false) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); RunFlatmm(a_ptr, @@ -667,15 +668,15 @@ struct FlatmmKernel else { RunFlatmm2(a_ptr, - b_flat_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr, - smem_ptr_pong, - kargs, - splitk_batch_offset, - i_m, - i_n); + b_flat_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr, + smem_ptr_pong, + kargs, + splitk_batch_offset, + i_m, + i_n); } } }; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index dfe91b584f..fd931b81dd 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -51,7 +51,8 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV2 } template - CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) { if(tail_number == TailNumber::Odd) { @@ -499,7 +500,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV template struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV2 { - using Base = BaseFlatmmPipelineAGmemBGmemCRegV2; + using Base = BaseFlatmmPipelineAGmemBGmemCRegV2; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -510,8 +511,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - - using BlockFlatmm = remove_cvref_t())>; @@ -585,7 +584,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr MfmaConfig GetMfmaConfig() { - // K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1 if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && std::is_same_v) || @@ -645,7 +643,6 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV // clang-format on } - static constexpr bool DoubleSmemBuffer = true; static constexpr index_t Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; @@ -1024,11 +1021,13 @@ struct FlatmmPipelineAGmemBGmemCRegV2 : public BaseFlatmmPipelineAGmemBGmemCRegV } else { - if constexpr ((A_LDS_Read_Inst_Num / 2 > - A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) { + if constexpr((A_LDS_Read_Inst_Num / 2 > + A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) + { static_for<0, - A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num, - 1>{}([&](auto i) { + A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - + B_Buffer_Load_Inst_Num, + 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 6ae5493f34..40ff34c977 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -339,36 +339,36 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } } } - + template CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() { using ADataType = remove_cvref_t; - //using ALayout = remove_cvref_t; + // using ALayout = remove_cvref_t; constexpr index_t BlockSize = Problem::kBlockSize; // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - constexpr index_t M1 = BlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - // constexpr index_t M0 = MPerBlock / (M2 * M1); - // static_assert(M0 * M1 * M2 == MPerBlock, - // "Incorrect M0, M2, M1 configuration! " - // "M0, M1, M2 must cover whole MPerBlock!"); + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + // constexpr index_t M0 = MPerBlock / (M2 * M1); + // static_assert(M0 * M1 * M2 == MPerBlock, + // "Incorrect M0, M2, M1 configuration! " + // "M0, M1, M2 must cover whole MPerBlock!"); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2>, - sequence<1>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2>, + sequence<1>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index f21136d2a8..30bea193b7 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -15,9 +15,9 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" @@ -29,14 +29,14 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 4fe3e4d55c..f5dbfd0cc3 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -13,9 +13,9 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" @@ -44,12 +44,12 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index f030679e60..a4a6f9a9cb 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -18,31 +18,31 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() { using ADataType = remove_cvref_t; - //using ALayout = remove_cvref_t; + // using ALayout = remove_cvref_t; constexpr index_t BlockSize = Problem::kBlockSize; // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - constexpr index_t M1 = BlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - // constexpr index_t M0 = MPerBlock / (M2 * M1); - // static_assert(M0 * M1 * M2 == MPerBlock, - // "Incorrect M0, M2, M1 configuration! " - // "M0, M1, M2 must cover whole MPerBlock!"); + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + // constexpr index_t M0 = MPerBlock / (M2 * M1); + // static_assert(M0 * M1 * M2 == MPerBlock, + // "Incorrect M0, M2, M1 configuration! " + // "M0, M1, M2 must cover whole MPerBlock!"); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2>, - sequence<1>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2>, + sequence<1>>{}); } // 3d + padding template diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp index 9c9f819aa0..84e47483de 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp @@ -33,7 +33,6 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV1 } }; - template struct WeightPreshufflePipelineAGmemBGmemCRegV1 : public BaseWeightPreshufflePipelineAGmemBGmemCRegV1 @@ -73,7 +72,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; - static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize/sizeof(ADataType); + static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType); static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr auto I0 = number<0>(); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 5d644a18d8..2cb1e22ea0 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -28,7 +28,8 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 } template - CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) { if(tail_number == TailNumber::Odd) { @@ -42,9 +43,10 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 }; template -struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePipelineAGmemBGmemCRegV2 +struct WeightPreshufflePipelineAGmemBGmemCRegV2 + : public BaseWeightPreshufflePipelineAGmemBGmemCRegV2 { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -79,7 +81,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; - static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize/sizeof(ADataType); + static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType); static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr auto I0 = number<0>(); @@ -105,18 +107,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - static constexpr index_t K1 = 16 / sizeof(ADataType); - static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; - static constexpr auto TailNum = Problem::TailNum; + static constexpr index_t K1 = 16 / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr auto TailNum = Problem::TailNum; static constexpr auto warp_m = WarpTile::at(idxM); static constexpr auto warp_n = WarpTile::at(idxN); static constexpr auto warp_k = WarpTile::at(idxK); - - - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -129,7 +127,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip // clang-format on } - static constexpr bool DoubleSmemBuffer = true; static constexpr index_t Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; @@ -508,11 +505,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip } else { - if constexpr ((A_LDS_Read_Inst_Num / 2 > - A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) { + if constexpr((A_LDS_Read_Inst_Num / 2 > + A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) + { static_for<0, - A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num, - 1>{}([&](auto i) { + A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - + B_Buffer_Load_Inst_Num, + 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -587,7 +586,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); -// A DRAM tile window for load + // A DRAM tile window for load auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -606,8 +605,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip {0, 0}, PipelinePolicy::template MakeADramTileDistribution()); - - // ping-pong window for A LDS auto a_warp_window_ping_tmp = make_tile_window(a_lds_block_ping, @@ -680,12 +677,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip NIterPerWarp> b_warp_tensor_pong; -// Prefetch A0 + // Prefetch A0 auto a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - // prefetch B static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -700,18 +696,17 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); -// Prefill A0 + // Prefill A0 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); __builtin_amdgcn_sched_barrier(0); -// Prefetch A1 + // Prefetch A1 a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -734,8 +729,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip }); __builtin_amdgcn_sched_barrier(0); - - index_t iCounter = (num_loop - 1) / 2; while(iCounter > 0) { @@ -783,7 +776,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds @@ -859,7 +851,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds @@ -896,7 +887,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip // tail if constexpr(TailNum == TailNumber::Even) { -// __builtin_amdgcn_sched_barrier(0); + // __builtin_amdgcn_sched_barrier(0); // prefetch B(loopK) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -936,7 +927,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds @@ -1028,7 +1018,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds @@ -1050,7 +1039,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip }); } - return c_block_tile; } @@ -1071,5 +1059,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 : public BaseWeightPreshufflePip } }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp index 318f3063c2..9ac45b9a9a 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp @@ -28,7 +28,8 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV3 } template - CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) { if(tail_number == TailNumber::Odd) { @@ -42,9 +43,10 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV3 }; template -struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePipelineAGmemBGmemCRegV3 +struct WeightPreshufflePipelineAGmemBGmemCRegV3 + : public BaseWeightPreshufflePipelineAGmemBGmemCRegV3 { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV3; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV3; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -118,10 +120,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip static constexpr auto warp_n = WarpTile::at(idxN); static constexpr auto warp_k = WarpTile::at(idxK); - - - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -135,7 +133,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip // clang-format on } - static constexpr bool DoubleSmemBuffer = true; static constexpr index_t Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; @@ -514,11 +511,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip } else { - if constexpr ((A_LDS_Read_Inst_Num / 2 > - A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) { + if constexpr((A_LDS_Read_Inst_Num / 2 > + A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) + { static_for<0, - A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num, - 1>{}([&](auto i) { + A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - + B_Buffer_Load_Inst_Num, + 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -552,7 +551,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip } } - template CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, @@ -570,7 +568,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); // static assert that warptile is 16x16 and not 32x32 - static_assert(WG::kM == 16 && WG::kN == 16, "For pipeline_AGmemBGmemCRegV3, WarpTile must be 16x16, not 32x32"); + static_assert(WG::kM == 16 && WG::kN == 16, + "For pipeline_AGmemBGmemCRegV3, WarpTile must be 16x16, not 32x32"); constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; const index_t iMWarp = get_warp_id() / NWarp; @@ -596,7 +595,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); -// A DRAM tile window for load + // A DRAM tile window for load auto a_copy_dram_window_tmp = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), @@ -636,8 +635,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip move_tile_window(a_copy_lds_window_pong(AIter), {AIter * ACopyPerLoadM, 0}); }); - - // ping-pong window for A LDS auto a_warp_window_ping_tmp = make_tile_window(a_lds_block_ping, @@ -710,7 +707,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip NIterPerWarp> b_warp_tensor_pong; -// Prefetch A0 + // Prefetch A0 statically_indexed_array{}))), ACopyLoadNum> a_block_tile; @@ -733,7 +730,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); -// Prefill A0 + // Prefill A0 static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { store_tile(a_copy_lds_window_ping(AIter), @@ -742,7 +739,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip __builtin_amdgcn_sched_barrier(0); -// Prefetch A1 + // Prefetch A1 static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); @@ -771,13 +768,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip }); __builtin_amdgcn_sched_barrier(0); - - index_t iCounter = (num_loop - 1) / 2; while(iCounter > 0) { - // GEMM 2i static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -801,7 +795,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - // prefetch B(2i+1) constexpr auto curMNIter = mIter * NIterPerWarp + nIter; if constexpr((curMNIter < NIterPerWarp * BLoadGap) && @@ -872,7 +865,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip // Next K - // GEMM 2i+1 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -895,7 +887,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - // prefetch B(2i+2) constexpr auto curMNIter = mIter * NIterPerWarp + nIter; if constexpr((curMNIter < NIterPerWarp * BLoadGap) && @@ -1153,7 +1144,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip }); } - return c_block_tile; } @@ -1174,5 +1164,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 : public BaseWeightPreshufflePip } }; - } // namespace ck_tile