From ae275aa105f49b6532f4724013142de57537e1de Mon Sep 17 00:00:00 2001 From: YC Lin Date: Wed, 9 Apr 2025 03:19:10 +0000 Subject: [PATCH] [GEMM] Refactor block gemm, pipeline, and policy of instruction schedule opt --- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 377 +++++++++++++++- ...peline_agmem_bgmem_creg_default_policy.hpp | 183 +------- .../ck_tile/99_toy_example/02_gemm/gemm.cpp | 12 +- .../ck_tile/99_toy_example/02_gemm/gemm.hpp | 161 ++++++- .../99_toy_example/02_gemm/grid_gemm.hpp | 67 +-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 403 ------------------ .../gemm_pipeline_problem.hpp | 190 --------- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 328 -------------- .../instruction_schedule/tile_gemm_shape.hpp | 32 -- .../instruction_schedule/tile_gemm_traits.hpp | 55 --- 10 files changed, 556 insertions(+), 1252 deletions(-) delete mode 100644 example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp delete mode 100644 example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_problem.hpp delete mode 100644 example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp delete mode 100644 example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_shape.hpp delete mode 100644 example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_traits.hpp diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index 683cabaf83..d6adf81054 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -3,9 +3,12 @@ #pragma once -#include "ck_tile/core.hpp" #include "block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + namespace ck_tile { // A Tile Window: global memory @@ -27,6 +30,9 @@ struct BlockGemmPipelineAGmemBGmemCReg CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { +#if defined(ENABLE_INSTRUCTION_SCH) + return Policy::template GetSmemSize(); +#else return integer_divide_ceil( sizeof(ADataType) * Policy::template MakeALdsBlockDescriptor().get_element_space_size(), @@ -34,8 +40,373 @@ struct BlockGemmPipelineAGmemBGmemCReg 16 + sizeof(BDataType) * Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); +#endif } +#if defined(ENABLE_INSTRUCTION_SCH) + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num : + A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num : + B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_a = + (num_mfma_per_issue - num_dswrite_per_issue_a * 2 >= 1) ? 2 : 1; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_a * num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "A/B Dram block window should have the same data type as appropriate " + "([A|B]DataType) defined in Problem definition!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); + + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + // Leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + // __builtin_amdgcn_sched_barrier(0); + return c_block_tile; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem); + } + +#else + template CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, @@ -217,9 +588,11 @@ struct BlockGemmPipelineAGmemBGmemCReg iCounter--; } #endif - return c_block_tile; } + +#endif + }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index 5a342c7dfa..917f86e960 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -3,13 +3,14 @@ #pragma once +#include "block_gemm_asmem_bsmem_creg.hpp" + #include "ck_tile/core.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "block_gemm_asmem_bsmem_creg.hpp" #include "config.h" @@ -260,30 +261,17 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy sequence<0, 1>>{}); } - template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - return BlockGemmASmemBSmemCReg{}; - } -}; - -#if 0 -// UniversalGemm Policy -struct UniversalGemmPipelineAgBgCrPolicy -{ +#if defined(ENABLE_INSTRUCTION_SCH) static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; - // static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; - // static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; - template CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; // 32 = 128 * 64 / 256 + constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; @@ -353,10 +341,7 @@ struct UniversalGemmPipelineAgBgCrPolicy { using BlockGemm = remove_cvref_t())>; using WG = typename BlockGemm::WarpGemm; - - // constexpr bool TransposeC = Problem::TransposeC; - // using CLayout = typename Problem::CLayout; - using CWarpDstr = typename WG::CWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; // In this case each thread has multiple consecutive elements in // N dimension, however consecutive threads' elements have stride. @@ -374,58 +359,6 @@ struct UniversalGemmPipelineAgBgCrPolicy return Problem::TransposeC; } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - { - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - { - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - // coalesce reading for each blocks - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() { @@ -469,104 +402,6 @@ struct UniversalGemmPipelineAgBgCrPolicy return smem_size_a + smem_size_b; } - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - using ADataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return a_lds_block_desc; - } - - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - using BDataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { @@ -586,7 +421,13 @@ struct UniversalGemmPipelineAgBgCrPolicy WarpGemm>; return BlockUniversalGemmAsBsCr{}; } -}; +#else + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + return BlockGemmASmemBSmemCReg{}; + } #endif +}; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp index e7f8b3b9f3..5fffcf7e36 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp @@ -1,10 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + #include -#include "ck_tile/host.hpp" -#include "reference_gemm.hpp" - #include "config.h" +#include "ck_tile/host.hpp" #include "gemm.hpp" +#include "reference_gemm.hpp" /* * Toy code of GEMM @@ -179,7 +181,9 @@ int main(int argc, char* argv[]) { // reference gemm ck_tile::HostTensor c_host_ref(c_lengths, c_strides); - reference_basic_gemm(a_host, b_host, c_host_ref); + reference_basic_gemm(a_host, + b_host, + c_host_ref); c_buf.FromDevice(c_host_dev.mData.data()); pass &= ck_tile::check_err(c_host_dev, c_host_ref); std::cout << "valid:" << (pass ? "y" : "n") << std::endl; diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp index 492b3f1ac0..631d43b25d 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -4,9 +4,9 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" #include "block_gemm_pipeline_agmem_bgmem_creg.hpp" #include "config.h" @@ -29,7 +29,28 @@ struct GridGemmProblem using CElementFunction = CElementFunction_; }; -#ifndef ENABLE_INSTRUCTION_SCH +#if defined(ENABLE_INSTRUCTION_SCH) +template +struct TileGemmShape +{ + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + + static constexpr index_t kM = BlockTile::at(number<0>{}); + static constexpr index_t kN = BlockTile::at(number<1>{}); + static constexpr index_t kK = BlockTile::at(number<2>{}); + + static constexpr bool PermuteA = PermuteA_; + static constexpr bool PermuteB = PermuteB_; +}; +#else template struct TileGemmShape { @@ -39,6 +60,91 @@ struct TileGemmShape }; #endif +#if defined(ENABLE_INSTRUCTION_SCH) +template +struct TileGemmTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + // TODO this can't be hardcoded here! Should be in policy! + static constexpr int _VectorSize = 16; + + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; + + static constexpr bool TransposeC = false; +}; + +template +struct TileGemmUniversalTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; + + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; + + static constexpr bool TransposeC = TransposeC_; +}; + +template +struct BlockGemmPipelineProblem +{ + using Traits = remove_cvref_t; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + + static constexpr bool kPadM = Traits::kPadM; + static constexpr bool kPadN = Traits::kPadN; + static constexpr bool kPadK = Traits::kPadK; + + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + + static constexpr bool TransposeC = Traits::TransposeC; +}; +#else template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() { +#if defined(ENABLE_INSTRUCTION_SCH) + // Block GEMM pipeline w/ instruction scheduling + using GemmShape = TileGemmShape, + sequence, + sequence, + PermuteA, + PermuteB>; + + using GemmUniversalTraits = + TileGemmUniversalTraits; + + using BlockGemmPipelineProblem_ = + BlockGemmPipelineProblem; +#else using BlockGemmPipelineProblem_ = BlockGemmPipelineProblem>; +#endif return BlockGemmPipelineAGmemBGmemCReg{}; } -#endif }; using GridGemm = GridGemm; diff --git a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp index 1458fd997c..8fb8cdbff7 100644 --- a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp @@ -1,15 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#ifdef ENABLE_INSTRUCTION_SCH -#include "instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "instruction_schedule/gemm_pipeline_problem.hpp" -#include "instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp" -#include "instruction_schedule/tile_gemm_shape.hpp" -#include "instruction_schedule/tile_gemm_traits.hpp" -#endif - #pragma once namespace ck_tile { @@ -47,8 +38,10 @@ struct GridGemm const auto id_tile = block2tile(id_block); - const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) * kNPerBlock); + const auto iM = + __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) * kMPerBlock); + const auto iN = + __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) * kNPerBlock); // A block window auto a_block_window = make_tile_window( @@ -58,62 +51,10 @@ struct GridGemm auto b_block_window = make_tile_window( b_grid, make_tuple(number{}, number{}), {iN, 0}); -#ifndef ENABLE_INSTRUCTION_SCH - // Block GEMM pipeline w/o instruction scheduling constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline(); __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; -#else - // Block GEMM pipeline w/ instruction scheduling - static constexpr index_t M_Tile = 128; - static constexpr index_t N_Tile = 128; - static constexpr index_t K_Tile = 64; - static constexpr index_t M_Warp = 2; - static constexpr index_t N_Warp = 2; - static constexpr index_t K_Warp = 1; - static constexpr index_t M_Warp_Tile = 16; - static constexpr index_t N_Warp_Tile = 16; - static constexpr index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - static constexpr bool PermuteA = false; - static constexpr bool PermuteB = false; - static constexpr bool TransposeC = false; - // static constexpr int kBlockPerCu = 1; - // static constexpr index_t TileParitionerGroupNum = 8; - // static constexpr index_t TileParitionerM01 = 4; - - using GemmShape = TileGemmShape, - sequence, - sequence, - PermuteA, - PermuteB>; - - using GemmUniversalTraits = TileGemmUniversalTraits; - - using UniversalGemmProblem = UniversalGemmPipelineProblem; - - constexpr auto block_gemm_pipeline = GemmPipelineAgBgCrCompV3(); - - __shared__ char p_smem_char[block_gemm_pipeline.GetSmemSize()]; -#endif const auto acc_block_tile = block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, diff --git a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp deleted file mode 100644 index fc9fc232fe..0000000000 --- a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ /dev/null @@ -1,403 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck_tile/core.hpp" - -#include "gemm_universal_pipeline_ag_bg_cr_policy.hpp" - -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/host/concat.hpp" - -namespace ck_tile { - -template -struct GemmPipelineAgBgCrCompV3 -{ - using PipelineImplBase = GemmPipelineAgBgCrImplBase; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t APackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using BlockGemm = remove_cvref_t())>; - using I0 = number<0>; - using I1 = number<1>; - using I2 = number<2>; - - static constexpr index_t BlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - - static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } - static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } - static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } - - static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } - static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } - - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kPadK = Problem::kPadK; - - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; - - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - return concat('_', "pipeline_AgBgCrCompV3", BlockSize, - concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), - concat('x', kPadM, kPadN, kPadK)); - // clang-format on - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - struct PipelineImpl : public PipelineImplBase - { - }; - - template <> - struct PipelineImpl : public PipelineImplBase - { - using Base = PipelineImplBase; - - CK_TILE_DEVICE static constexpr auto HotLoopScheduler() - { - constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; - constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; - constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; - - constexpr index_t WaveSize = 64; - constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); - constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); - - // Below should be equal to AK1|BK1 - constexpr index_t A_LDS_Read_Width = GetSmemPackA(); - constexpr index_t B_LDS_Read_Width = GetSmemPackB(); - - constexpr index_t A_LDS_Write_Width = GetSmemPackA(); - constexpr index_t B_LDS_Write_Width = GetSmemPackB(); - - constexpr index_t A_Buffer_Load_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); - constexpr index_t B_Buffer_Load_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); - - constexpr index_t A_LDS_Write_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); - constexpr index_t B_LDS_Write_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); - - constexpr index_t A_LDS_Read_Inst_Num = - WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); - constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); - - constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / - (BlockSize / WaveSize) / - (MPerXDL * NPerXDL * KPerXDL); - - // A/B split schedule - // compiler is likely to use ds_read2 when instruction width smaller than 16bytes - constexpr auto num_ds_read_inst_a = - A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num - : A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = - B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num - : B_LDS_Read_Inst_Num / 2; - - constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; - constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; - - constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma_inst = C_MFMA_Inst_Num; - - constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; - constexpr auto ds_read_a_issue_cycle = - A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; - constexpr auto ds_read_b_issue_cycle = - B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; - constexpr auto ds_read_a_mfma_rate = - (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); - constexpr auto ds_read_b_mfma_rate = - (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); - - constexpr auto num_dsread_a_mfma = - (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; - constexpr auto num_dsread_b_mfma = - (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; - - // stage 1 - // Separate this part? - // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > - // sizeof(ComputeDataType) / - // sizeof(BDataType) - // ? sizeof(ComputeDataType) / - // sizeof(ADataType) : sizeof(ComputeDataType) - // / sizeof(BDataType); - constexpr auto num_mfma_stage1 = - num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); - constexpr auto num_mfma_per_issue = - num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); - constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; - constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; - constexpr auto num_mfma_per_dswrite_a = - (num_mfma_per_issue - num_dswrite_per_issue_a * 2 >= 1) ? 2 : 1; - constexpr auto num_mfma_per_dswrite_b = - (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; - - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { - ignore = i; - static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { - ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier( - 0x008, num_mfma_per_issue - num_mfma_per_dswrite_a * num_dswrite_per_issue_a, 0); // MFMA - }); - static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { - ignore = i; - static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { - ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier( - 0x008, num_mfma_per_issue - num_mfma_per_dswrite_b * num_dswrite_per_issue_b, 0); // MFMA - }); - - // stage 2 - static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { - if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= - ds_read_a_mfma_rate) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x100, - num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, - 0); // DS read - } - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - - static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { - if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= - ds_read_b_mfma_rate) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x100, - num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, - 0); // DS read - } - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - } - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - static_assert( - std::is_same_v> && - std::is_same_v>, - "A/B Dram block window should have the same data type as appropriate " - "([A|B]DataType) defined in Problem definition!"); - - // ------------------------------------------------------------------------------------ - // Definitions of all needed tiles - - // A/B tiles in LDS - auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); - - // A DRAM tile window for load - // A LDS tile window for store - // A LDS tile for block GEMM - auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = - Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); - - // B DRAM tile window for load - // B LDS tile window for store - // B LDS tile for block GEMM - auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = - Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); - - // Block GEMM - auto block_gemm = BlockGemm(); - auto c_block_tile = block_gemm.MakeCBlockTile(); - - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); - - using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); - using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); - - ABlockTile a_block_tile; - BBlockTile b_block_tile; - - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - - constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); - - // ----------------------------------------------------------------------------------------- - // Gemm pipeline start - - // prefetch - // global read 0 - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - - block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - - __builtin_amdgcn_sched_barrier(0); - - // main body - if constexpr(HasHotLoop) - { - index_t i = 0; - do - { - block_sync_lds(); - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); - - i += 1; - } while(i < (num_loop - 1)); - } - // tail - if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) - { - // Leak last MFMA block to epilogue region, cover the potential lds-shuffle - // latency - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } - else - { - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } - // __builtin_amdgcn_sched_barrier(0); - return c_block_tile; - } - }; - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); - } - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const - { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, - num_loop, - p_smem); - } -}; - -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_problem.hpp b/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_problem.hpp deleted file mode 100644 index 34e8fa18c8..0000000000 --- a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_pipeline_problem.hpp +++ /dev/null @@ -1,190 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/host/concat.hpp" - -namespace ck_tile { - -template -struct GemmPipelineProblemBase -{ - using Traits = remove_cvref_t; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - - using BlockGemmShape = remove_cvref_t; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - static constexpr bool TransposeC = Traits::TransposeC; - - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - - static constexpr bool kPadM = Traits::kPadM; - static constexpr bool kPadN = Traits::kPadN; - static constexpr bool kPadK = Traits::kPadK; - - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - static constexpr index_t VectorLoadSize = Traits::_VectorSize; - - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() - { - constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) - { - constexpr index_t pixels_per_thread = - BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; - return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType) - ? pixels_per_thread - : PackedSize * VectorLoadSize / sizeof(ADataType); - } - else - { - return VectorLoadSize / sizeof(ADataType); - } - } - - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB() - { - constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) - { - constexpr index_t pixels_per_thread = - BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; - return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType) - ? pixels_per_thread - : PackedSize * VectorLoadSize / sizeof(BDataType); - } - else - { - return PackedSize * VectorLoadSize / sizeof(BDataType); - } - } - - CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC() - { - if constexpr(std::is_same_v) - { - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size()); - constexpr index_t M0 = get_warp_size() / N2; - constexpr index_t M1 = BlockGemmShape::kM / M0; - - return std::min(M1, static_cast(VectorLoadSize / sizeof(CDataType))); - } - else - { - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size()); - constexpr index_t N0 = get_warp_size() / M2; - constexpr index_t N1 = BlockGemmShape::kN / N0; - - return std::min(N1, static_cast(VectorLoadSize / sizeof(CDataType))); - } - } - - static constexpr index_t VectorSizeA = []() { - if constexpr(std::is_same_v) - { - return kPadK ? 1 : GetAlignmentA(); - } - else - { - return kPadM ? 1 : GetAlignmentA(); - } - }(); - - static constexpr index_t VectorSizeB = []() { - if constexpr(std::is_same_v) - { - return kPadN ? 1 : GetAlignmentB(); - } - else - { - return kPadK ? 1 : GetAlignmentB(); - } - }(); - static constexpr index_t VectorSizeC = []() { - if constexpr(std::is_same_v) - { - return kPadN ? 1 : GetAlignmentC(); - } - else - { - return kPadM ? 1 : GetAlignmentC(); - } - }(); -}; - -// Alias for GemmPipelineProblem -template -using GemmPipelineProblem = GemmPipelineProblemBase; - -template -struct UniversalGemmPipelineProblem -{ - using Traits = remove_cvref_t; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - - using BlockGemmShape = remove_cvref_t; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - - static constexpr bool kPadM = Traits::kPadM; - static constexpr bool kPadN = Traits::kPadN; - static constexpr bool kPadK = Traits::kPadK; - - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - - static constexpr auto Scheduler = Scheduler_; - static constexpr auto HasHotLoop = HasHotLoop_; - static constexpr auto TailNum = TailNum_; - - static constexpr bool TransposeC = Traits::TransposeC; -}; - -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp deleted file mode 100644 index b147e8f178..0000000000 --- a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ /dev/null @@ -1,328 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" - -namespace ck_tile { - -// UniversalGemm Policy -struct UniversalGemmPipelineAgBgCrPolicy -{ - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - - template - CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; - constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - - // Assume DataType is even! - if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 && - PackedSize == 2) - { - return (PackedSize * 32 / sizeof(DataType)); - } - else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) - { - return (PackedSize * 16 / sizeof(DataType)); - } - else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) - { - return (PackedSize * 8 / sizeof(DataType)); - } - else if constexpr(sizeof(DataType) >= PackedSize * 4 && - XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0) - { - return (PackedSize * 4 / sizeof(DataType)); - } - else if constexpr(sizeof(DataType) >= PackedSize * 2 && - XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0) - { - return (PackedSize * 2 / sizeof(DataType)); - } - else - { - return PackedSize; - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() - { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - static_assert(std::is_same_v); - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - return GetGlobalVectorLoadSize(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() - { - using BDataType = remove_cvref_t; - using BLayout = remove_cvref_t; - static_assert(std::is_same_v); - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - return GetGlobalVectorLoadSize(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() - { - using BlockGemm = remove_cvref_t())>; - using WG = typename BlockGemm::WarpGemm; - using CWarpDstr = typename WG::CWarpDstr; - - // In this case each thread has multiple consecutive elements in - // N dimension, however consecutive threads' elements have stride. - constexpr index_t NDimY = CWarpDstr::NDimY; - constexpr auto c_warp_y_lengths = - CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); - static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == - c_warp_y_lengths.get(number{})); - return c_warp_y_lengths.get(number{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() - { - return Problem::TransposeC; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - { - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - { - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - // coalesce reading for each blocks - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() - { - using BlockGemm = remove_cvref_t())>; - constexpr index_t KPack = BlockGemm::Traits::KPack; - return KPack; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() - { - using BlockGemm = remove_cvref_t())>; - constexpr index_t KPack = BlockGemm::Traits::KPack; - return KPack; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() - { - constexpr auto a_lds_desc = MakeALdsBlockDescriptor(); - constexpr index_t smem_size_a = integer_least_multiple( - sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); - return smem_size_a; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() - { - constexpr auto b_lds_desc = MakeBLdsBlockDescriptor(); - constexpr index_t smem_size_b = integer_least_multiple( - sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); - return smem_size_b; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - constexpr index_t smem_size_a = GetSmemSizeA(); - constexpr index_t smem_size_b = GetSmemSizeB(); - - return smem_size_a + smem_size_b; - } - - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - using ADataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return a_lds_block_desc; - } - - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - using BDataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; - return BlockUniversalGemmAsBsCr{}; - } -}; - -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_shape.hpp b/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_shape.hpp deleted file mode 100644 index 0212c1c9f8..0000000000 --- a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_shape.hpp +++ /dev/null @@ -1,32 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/host/concat.hpp" - -namespace ck_tile { - -template -struct TileGemmShape -{ - using BlockTile = remove_cvref_t; - using BlockWarps = remove_cvref_t; - using WarpTile = remove_cvref_t; - - static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); - - static constexpr index_t kM = BlockTile::at(number<0>{}); - static constexpr index_t kN = BlockTile::at(number<1>{}); - static constexpr index_t kK = BlockTile::at(number<2>{}); - - static constexpr bool PermuteA = PermuteA_; - static constexpr bool PermuteB = PermuteB_; -}; - -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_traits.hpp b/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_traits.hpp deleted file mode 100644 index d0e1f60d38..0000000000 --- a/example/ck_tile/99_toy_example/02_gemm/instruction_schedule/tile_gemm_traits.hpp +++ /dev/null @@ -1,55 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -template -struct TileGemmTraits -{ - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool kPadK = kPadK_; - - // TODO this can't be hardcoded here! Should be in policy! - static constexpr int _VectorSize = 16; - - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; - - static constexpr bool TransposeC = false; -}; - -template -struct TileGemmUniversalTraits -{ - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool kPadK = kPadK_; - - static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; - - static constexpr bool TransposeC = TransposeC_; -}; - -} // namespace ck_tile